边写代码边学习之LSTM

1.  什么是LSTM

长短期记忆网络 LSTM(long short-term memory)是 RNN 的一种变体,其核心概念在于细胞状态以及“门”结构。细胞状态相当于信息传输的路径,让信息能在序列连中传递下去。你可以将其看作网络的“记忆”。理论上讲,细胞状态能够将序列处理过程中的相关信息一直传递下去。因此,即使是较早时间步长的信息也能携带到较后时间步长的细胞中来,这克服了短时记忆的影响。信息的添加和移除我们通过“门”结构来实现,“门”结构在训练过程中会去学习该保存或遗忘哪些信息。
 

在这里插入图片描述

 

2. 实验代码

2.1. 搭建一个只有一层RNN和Dense网络的模型。

2.2. 验证LSTM里的逻辑

 假设我的输入数据是x = [1,0], 

kernel = [[[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],

              [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],]]

recurrent_kernel = [[1, 0, 0, 1, 2,1,0,1,2,0,1,0],

                              [1, 1, 0, 0, 2,1,0,1,2,2,0,0],

                              [1, 0, 1, 2, 0,1,0,1,1,0,1,0]]

biase = [3, 1, 0, 1, 1,0,0,1,0,2,0.0,0]

通过下面手算,h的结果是[0, 4,1], c 的结果是[0,4,1].  注意无激活函数。

代码验证上面的结果


def change_weight():
    # Create a simple Dense layer
    lstm_layer = LSTM(units=3, input_shape=(3, 2), activation=None, recurrent_activation=None, return_sequences=True,
                      return_state= True)

    # Simulate input data (batch size of 1 for demonstration)
    input_data = np.array([
                [[1.0, 2], [2, 3], [3, 4]],
                [[5, 6], [6, 7], [7, 8]],
                [[9, 10], [10, 11], [11, 12]]
        ])

    # Pass the input data through the layer to initialize the weights and biases
    lstm_layer(input_data)

    kernel, recurrent_kernel, biases = lstm_layer.get_weights()

    # Print the initial weights and biases
    print("recurrent_kernel:", recurrent_kernel, recurrent_kernel.shape ) # (3,3)
    print('kernal:',kernel, kernel.shape) #(2,3)
    print('biase: ',biases , biases.shape) # (3)


    kernel = np.array([[2, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0],
                       [1, 1, 0, 1, 1, 0, 0, 1, 1 ,0, 0, 0],])

    recurrent_kernel = np.array([[1, 0, 0, 1, 2,1,0,1,2,0,1,0],
                                 [1, 1, 0, 0, 2,1,0,1,2,2,0,0],
                                 [1, 0, 1, 2, 0,1,0,1,1,0,1,0]])

    biases = np.array([3, 1, 0, 1, 1,0,0,1,0,2,0.0,0])

    lstm_layer.set_weights([kernel, recurrent_kernel, biases])
    print(lstm_layer.get_weights())

    # test_data = np.array([
    #     [[1.0, 3], [1, 1], [2, 3]]
    # ])

    test_data = np.array([
        [[1,0.0]]
    ])

    output, memory_state, carry_state  = lstm_layer(test_data)

    print(output)
    print(memory_state)
    print(carry_state)
if __name__ == '__main__':
    change_weight()

执行结果:

recurrent_kernel: [[-0.36744034 -0.11181469 -0.10642298  0.5450207  -0.30208975  0.5405432
   0.09643812 -0.14983998  0.1859854   0.2336958  -0.16187981  0.11621032]
 [ 0.07727922 -0.226477    0.1491096  -0.03933501  0.31236103 -0.12963092
   0.10522162 -0.4815724  -0.2093935   0.34740582 -0.60979587 -0.15877807]
 [ 0.15371156  0.01244636 -0.09840634 -0.32093546  0.06523462  0.18934932
   0.38859126 -0.3261706  -0.05138849  0.42713478  0.49390993  0.37013963]] (3, 12)
kernal: [[-0.47606698 -0.43589187 -0.5371355  -0.07337284  0.30526626 -0.18241835
  -0.03675252  0.2873094   0.33218485  0.24838251  0.17765659  0.4312396 ]
 [ 0.4007727   0.41280174  0.40750778 -0.6245315   0.6382301   0.42889225
   0.11961156 -0.6021105  -0.43556038  0.39798307  0.6390712   0.16719025]] (2, 12)
biase:  [0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0.] (12,)
[array([[2., 1., 1., 0., 0., 0., 0., 1., 1., 0., 1., 0.],
       [1., 1., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0.]], dtype=float32), array([[1., 0., 0., 1., 2., 1., 0., 1., 2., 0., 1., 0.],
       [1., 1., 0., 0., 2., 1., 0., 1., 2., 2., 0., 0.],
       [1., 0., 1., 2., 0., 1., 0., 1., 1., 0., 1., 0.]], dtype=float32), array([3., 1., 0., 1., 1., 0., 0., 1., 0., 2., 0., 0.], dtype=float32)]
tf.Tensor([[[0. 4. 0.]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[0. 4. 0.]], shape=(1, 3), dtype=float32)
tf.Tensor([[0. 4. 1.]], shape=(1, 3), dtype=float32)

可以看出h=[0,4,0], c=[0,4,1]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/62984.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Stable Diffusion - Candy Land (糖果世界) LoRA 提示词配置与效果展示

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/132145248 糖果世界 (Candy Land) 是一个充满甜蜜和奇幻的地方,由各种各样的糖果和巧克力构成。在糖果世界,可以看到&…

Qt、C/C++环境中内嵌LUA脚本、实现LUA函数的调用执行

Qt、C/C环境中内嵌LUA脚本、实现LUA函数的调用执行 Chapter1. Qt、C/C环境中内嵌LUA脚本、实现LUA函数的调用执行1、LUA简介2、LUA脚本的解释器和编译器3、C环境中内嵌LUA执行LUA函数调用4、Qt内嵌LUA执行LUA函数调用5、运行结果6、内嵌LUA脚本在实际项目中的案例应用 Chapter1…

手机变电脑2023之虚拟电脑droidvm

手机这么大的内存,装个app来模拟linux,还是没问题的。 app 装好后,手指点几下确定按钮,等几分钟就能把linux桌面环境安装好。 不需要敲指令, 不需要对手机刷机, 不需要特殊权限, 不需要找驱…

opencv-32 图像平滑处理-高斯滤波cv2.GaussianBlur()

在进行均值滤波和方框滤波时,其邻域内每个像素的权重是相等的。在高斯滤波中,会将中心点的权重值加大,远离中心点的权重值减小,在此基础上计算邻域内各个像素值不同权重 的和。 基本原理 在高斯滤波中,卷积核中的值不…

第一篇:一文看懂 Vue.js 3.0 的优化

我们的课程是要解读 Vue.js 框架的源码,所以在进入课程之前我们先来了解一下 Vue.js 框架演进的过程,也就是 Vue.js 3.0 主要做了哪些优化。 Vue.js 从 1.x 到 2.0 版本,最大的升级就是引入了虚拟 DOM 的概念,它为后续做服务端渲…

Scala按天写入日志文件

如果希望把每天出错的信息写入日志文件,每天新建一个文件。 package test.scala import java.io.{File, FileWriter} import java.text.SimpleDateFormat import java.util.{Calendar, Date} import scala.concurrent.ExecutionContext.Implicits.global import sc…

CS录屏教程,录制游戏需要注意哪些方面?

​最近有个CS手游的玩家小伙伴咨询想要做一些游戏视频录制,但是不知道有哪些好用的工具来使用,对于游戏录制我们其实是需要注意一些事项的,想要观众的观感上比较好就需要把握好视频的帧率等问题,下面我们就来看看录制方法和需要注…

最小二乘问题和非线性优化

最小二乘问题和非线性优化 0.引言1.最小二乘问题2.迭代下降法3.最速下降法4.牛顿法5.阻尼法6.高斯牛顿(GN)法7.莱文贝格马夸特(LM)法8.鲁棒核函数 0.引言 转载自此处,修正了一点小错误。 1.最小二乘问题 在求解 SLAM 中的最优状态估计问题时,我们一般…

RISC-V - 小记

文章目录 关于 RISC-V安装 关于 RISC-V RISC : Reduced Instruction Set Computing RISC-V(“RISC five”)的目标是成为一个通用的指令集架构(ISA) 官网:https://riscv.orggithub : https://github.com/riscv 教程 [完结] 循序渐进,学习开发一个RISC-…

【stm32】初识stm32—stm32环境的搭建

文章目录 🛸stm32资料分享🍔stm32是什么🎄具体过程🏳️‍🌈安装驱动🎈1🎈2 🏳️‍🌈建立Start文件夹 🛸stm32资料分享 我用夸克网盘分享了「STM32入门教程资料…

Maven入职学习

一、什么是Maven? 概念: Maven是一种框架。它可以用作依赖管理工具、构建工具。 它可以管理jar包的规模、jar包的来源、jar包之间的依赖关系。 它的用途就是管理规模庞大的jar包,脱离IDE环境执行构建操作。 具体使用: 工作机…

yolov5代码解读之​detect.py文件【超详细的好吗!点进来看阿很用心的!】

yolov5的代码一直在更新,所以你们代码有些部分可能不太一样,但大差不差。 先给大家看一下项目结构:(最好有这个项目,且跑通过) detect.py文件:它可以预测视频、图片文件夹、网络流等等。 如何…

Django实现音乐网站 ⑺

使用Python Django框架制作一个音乐网站, 本篇主要是后台对歌手原有实现功能的基础上进行优化处理。 目录 新增编辑 表字段名称修改 隐藏单曲、专辑数 姓名首字母 安装xpinyin 获取姓名首字母 重写保存方法 列表显示 图片显示处理 引入函数 路径改为显示…

EP4CE6E22C8 FPGA最小系统电路原理图+PCB源文件

资料下载地址:EP4CE6E22C8 FPGA最小系统电路原理图PCB源文件 一、原理图 二、PCB

HTTP——八、确认访问用户身份的认证

HTTP 一、何为认证二、BASIC认证BASIC认证的认证步骤 三、DIGEST认证DIGEST认证的认证步骤 四、SSL客户端认证1、SSL 客户端认证的认证步骤2、SSL 客户端认证采用双因素认证3、SSL 客户端认证必要的费用 五、基于表单认证1、认证多半为基于表单认证2、Session 管理及 Cookie 应…

mysql二进制方式升级8.0.34

一、概述 mysql8.0.33 存在如下高危漏洞&#xff0c;需要通过升级版本修复漏洞 Oracle MySQL Cluster 安全漏洞(CVE-2023-0361) mysql/8.0.33 Apache Skywalking <8.3 SQL注入漏洞 二、查看mysql版本及安装包信息 [rootlocalhost mysql]# mysql -V mysql Ver 8.0.33 fo…

class version 61 java version 17.0.4

class version (javap -verbose xxxx.class)_spencer_tseng的博客-CSDN博客

Demystifying Prompts in Language Models via Perplexity Estimation

Demystifying Prompts in Language Models via Perplexity Estimation 原文链接 Gonen H, Iyer S, Blevins T, et al. Demystifying prompts in language models via perplexity estimation[J]. arXiv preprint arXiv:2212.04037, 2022. 简单来说就是作者通过在不同LLM和不同…

程序框架-事件中心模块-观察者模式

一、观察者模式 1.1 观察者模式定义 意图&#xff1a; 定义对象间的一种一对多的依赖关系&#xff0c;当一个对象的状态发生改变是&#xff0c;所有依赖于它的对象都能得到通知并自动更新。 适用性&#xff1a; 当一个对象状态的改变需要改变其他对象&#xff0c; 或实际对…

Python自动化实战之使用Pytest进行API测试详解

概要 每次手动测试API都需要重复输入相同的数据&#xff0c;而且还需要跑多个测试用例&#xff0c;十分繁琐和无聊。那么&#xff0c;有没有一种方法可以让你更高效地测试API呢&#xff1f;Pytest自动化测试&#xff01;今天&#xff0c;小编将向你介绍如何使用Pytest进行API自…