机器学习(10.7-10.13)(Pytorch LSTM和LSTMP的原理及其手写复现)

文章目录

    • 摘要
    • Abstract
    • 1 LSTM
    • 1.1 使用Pytorch LSTM
      • 1.1.1 LSTM API代码实现
      • 1.1.2 LSTMP代码实现
    • 1.2 手写一个lstm_forward函数 实现单向LSTM的计算原理
    • 1.3 手写一个lstmp_forward函数 实现单向LSTMP的计算原理
    • 总结

摘要

LSTM是RNN的一个优秀的变种模型,继承了大部分RNN模型的特性;在本次学习中,展示了LSTM的手动推导过程,用代码逐行模拟实现LSTM的运算过程,并与Pytorch API输出的结验证是否保持一致。

Abstract

LSTM is an excellent variant model of RNN, inheriting most of the characteristics of RNN model. In this study, the manual derivation process of LSTM is demonstrated, and the operation process of LSTM is simulated line by line with the code, and whether it is consistent with the junction output of Pytorch API is verified.

1 LSTM

1.1 使用Pytorch LSTM

实例化LSTM类,需要传递的参数:

input_size:输入数据的特征维度
hidden_size:隐含状态 h t h_t ht的大小
num_layers:默认值为1,大于1,表示多个RNN堆叠起来
batch_first:默认是False;若为True,输入输出格式为:(batch, seq, feature) ;若为False:输入输出格式为: (seq, batch, feature);
bidirectional:默认为False;若为True,则是双向RNN,同时输出长度为2*hidden_size
proj_size:若不为None,将使用具有相应大小投影的 LSTM

函数输入值:(input,(h_0,c_0))

  1. input:当batch_first=True 输入格式为(N,L, H i n H_{in} Hin);当batch_first=False 输入格式为(L,N, H i n H_{in} Hin);
  2. h_0:默认输入值为0;输入格式为:(D*num_layers,N, H o u t H_{out} Hout
  3. c_0:默认输入值为0,输入格式为:(D*num_layers,N, H c e l l H_{cell} Hcell

其中:

N = batch size
L = sequence length
D = 2 if bidirectional=True otherwise 1
H i n H_{in} Hin = input_size
H o u t H_{out} Hout = hidden_size
H c e l l H_{cell} Hcell = proj_size if proj_size>0 否则 hidden_size

函数输出值:(output,(h_n, c_n))

  1. output:当batch_first=True 输出格式为(N,L,D* H o u t H_{out} Hout);当batch_first=False 输出格式为(L,N,D* H o u t H_{out} Hout);
  2. h_n:输出格式为:(D*num_layers,N, H o u t H_{out} Hout
  3. c_n:输出格式为:(D*num_layers,N, H c e l l H_{cell} Hcell

1.1.1 LSTM API代码实现

首先实例化一些参数:

# 定义常量
bs, T, input_size, hidden_size = 2, 3, 4, 5

#输入序列
input = torch.randn(bs, T, input_size)
c0 = torch.randn(bs, hidden_size)
h0 = torch.randn(bs, hidden_size)

调用Pytorch中的LSTM API:
并查看返回的结果及其形状

# 调用官方的LSTM API
lstm_layer = nn.LSTM(input_size, hidden_size, batch_first=True)
output, (hn, cn) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))

print("----官方LSTM的输出----")
print(output)
print(output.shape) # torch.Size([2, 3, 5])---[bs, T, hidden_size]
print(hn)
print(hn.shape) # torch.Size([1, 2, 5])---[num_layers,bs,hidden_size]
print(cn)
print(cn.shape) # torch.Size([1, 2, 5])---[num_layers,bs,hidden_size]

输出为:

----官方LSTM的输出----
tensor([[[ 0.3045, -0.0770,  0.0110, -0.1592, -0.2298],
         [ 0.1142, -0.1522, -0.0975, -0.1512,  0.2834],
         [ 0.1298,  0.0090,  0.0183, -0.2742,  0.0428]],

        [[ 0.1908, -0.1505, -0.0659,  0.0520,  0.0340],
         [ 0.2583, -0.0599, -0.0442, -0.1428, -0.0102],
         [ 0.1604,  0.0004, -0.1105, -0.2062,  0.0760]]],
       grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 5])
tensor([[[ 0.1298,  0.0090,  0.0183, -0.2742,  0.0428],
         [ 0.1604,  0.0004, -0.1105, -0.2062,  0.0760]]],
       grad_fn=<StackBackward0>)
torch.Size([1, 2, 5])
tensor([[[ 0.3637,  0.0263,  0.0326, -0.5326,  0.0537],
         [ 0.4828,  0.0008, -0.2197, -0.3999,  0.1016]]],
       grad_fn=<StackBackward0>)
torch.Size([1, 2, 5])

获取 pytorch LSTM内置参数的形状

for k, v in lstm_layer.named_parameters():
    print(k, v)

在这里插入图片描述

其中:

weight_ih_l0维度为:[4*hidden_size,input_size]
weight_hh_l0维度为:[4*hidden_size,hidden_size]
bias_ih_l0维度为:[4*hidden_size]
bias_hh_l0维度为:[4*hidden_size]

1.1.2 LSTMP代码实现

如果指定proj_size >0 ,将会使用带投影的LSTM,以下列方式更改LSTM单元。
首先, h t h_t ht的维度将从hidden_size更改为:proj_size
其次,每一层的输出隐藏状态将乘以一个可学习的投影矩阵, h t = w h r h t h_t=w_{hr}h_t ht=whrht

因此,LSTMP 网络的输出也将具有不同的形状。

首先实例化一些参数:

bs, T, input_size, hidden_size = 2, 3, 4, 5
proj_size = 3

input = torch.randn(bs, T, input_size)
h0 = torch.randn(bs, proj_size)
c0 = torch.randn(bs, hidden_size)

调用PyTorch中的 LSTMP API:
并查看返回的结果及其形状

# 使用官方的LSTMP
layer_LSTMP = nn.LSTM(input_size, hidden_size, batch_first=True, proj_size=proj_size)
output, (h_n, c_n) = layer_LSTMP(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output)
print(output.shape) # torch.Size([2, 3, 3])---[bs,T,proj_size]
print(h_n)
print(h_n.shape)    # torch.Size([1, 2, 3])---[num_layers,bs,proj_size]
print(c_n)
print(c_n.shape)    # torch.Size([1, 2, 5])---[num_layers,bs,hidden_size]

输出为:

tensor([[[-0.1677, -0.0150, -0.0289],
         [-0.1369, -0.0738, -0.0311],
         [ 0.0168, -0.0184,  0.0303]],

        [[-0.0935, -0.0204, -0.1073],
         [-0.0137, -0.0056, -0.0085],
         [-0.0654, -0.0376,  0.0562]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 3])
tensor([[[ 0.0168, -0.0184,  0.0303],
         [-0.0654, -0.0376,  0.0562]]], grad_fn=<StackBackward0>)
torch.Size([1, 2, 3])
tensor([[[ 0.1691, -0.0772,  0.0209,  0.1153,  0.2290],
         [ 0.4153, -0.0653, -0.0928,  0.3042,  0.2146]]],
       grad_fn=<StackBackward0>)
torch.Size([1, 2, 5])

获取 pytorch LSTMP内置参数

for k, v in layer_LSTMP.named_parameters():
    print(k, v.shape)

在这里插入图片描述

其中:

weight_ih_l0维度为:[4*hidden_size,input_size]
weight_hh_l0维度为:[4*hidden_size,Proj_size]
bias_ih_l0维度为:[4*hidden_size]
bias_hh_l0维度为:[4*hidden_size]
weight_hr_l0维度为:[Proj_size,hidden_size]

1.2 手写一个lstm_forward函数 实现单向LSTM的计算原理

这里先将lstm_forward函数中的每个参数的维度写出来:

在这里插入图片描述

def lstm_forward(input, w_ih, w_hh, b_ih, b_hh, initial_states):
    h0, c0 = initial_states # 初始状态
    bs, T, input_size = input.shape
    hidden_size = w_ih.shape[0]//4

    prev_h = h0
    prev_c = c0
    # input 的维度为:[batch_size, T, input_size]
    # 其中 w_ih和w_hh 的维度为:[4*hidden_size, input_size],[4*hidden_size, hidden_size]
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1) # 这样w_ih的变为:[batch_size, 4*hidden_size, input_size]
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1) # 这样w_hh的变为:[batch_size, 4*hidden_size, hidden_size]

    output_size = hidden_size
    print(output_size)    # 5
    # 输出序列
    output = torch.zeros(bs, T, output_size)

    for t in range(T):
        # 获取当前时刻的x
        x = input[:, t, :]  #这时候x的维度为:[bs, input_size]
        # 进行矩阵运算,先将x的维度扩为:[bs, input_size, 1]
        # 矩阵运算之后的维度为:[bs, 4*hidden_size, 1]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))
        w_times_x = w_times_x.squeeze(-1)    # 此时维度为:[bs, 4*hidden_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))    # [bs, 4*hidden_size,1]
        w_times_h_prev = w_times_h_prev.squeeze(-1) # 维度为:[bs, 4*hidden_size]

        # 计算输入门(i),遗忘门(f),cell门(g),输出门(o)
        i_t = torch.sigmoid(w_times_x[:, :hidden_size] + w_times_h_prev[:, :hidden_size] + b_ih[:hidden_size] + b_hh[:hidden_size])
        f_t = torch.sigmoid(w_times_x[:, hidden_size:2*hidden_size] + w_times_h_prev[:, hidden_size:2*hidden_size] +
                            b_ih[hidden_size:2*hidden_size] + b_hh[hidden_size:2*hidden_size])
        g_t = torch.tanh(w_times_x[:, 2*hidden_size:3*hidden_size] + w_times_h_prev[:, 2*hidden_size:3*hidden_size] +
                            b_ih[2*hidden_size:3*hidden_size] + b_hh[2*hidden_size:3*hidden_size])
        o_t = torch.sigmoid(w_times_x[:, 3*hidden_size:4*hidden_size] + w_times_h_prev[:, 3*hidden_size:4*hidden_size] +
                            b_ih[3*hidden_size:4*hidden_size] + b_hh[3*hidden_size:4*hidden_size])

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)   # [bs, hidden_size]

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)
  • 使用pytorch中LSTM的内置参数测试该函数
c_output, (c_hn, c_cn) = lstm_forward(input, lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,
                                      lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0, (h0, c0))
print("----手写LSTM的输出----")
print(c_output)
print(c_hn)
print(c_cn)

查看两者的输出结果完全相同
在这里插入图片描述

1.3 手写一个lstmp_forward函数 实现单向LSTMP的计算原理

def lstmp_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    h0, c0 = initial_states
    bs, T, input_size = input.shape
    hidden_size = w_ih.shape[0]//4

    prev_h = h0 # [bs, proj_size]
    prev_c = c0 # [bs, hidden_size]

    # w_ih维度:[4*h_size, input_size] w_hh维度:[4*h_size, proj_size]    w_hr维度:[proj_size, h_size]
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)


    if w_hr is not None:
        proj_size = w_hr.shape[0]
        output_size = proj_size
        batch_w_hr = w_hr.unsqueeze(0).tile(bs, 1, 1)
    else:
        output_size = hidden_size

    output = torch.zeros(bs, T, output_size)

    for t in range(T):
        x = input[:, t, :]  # [bs, input_size]1

        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)).squeeze(-1) # [bs, 4*input_size]
        w_times_prev_h = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)).squeeze(-1)    #[ba, 4*hidden_size]

        # 分明计算输入门(i),遗忘门(f),cell门(c), 输出门(o)
        i_t = torch.sigmoid(w_times_x[:, :hidden_size] + w_times_prev_h[:, :hidden_size] + b_ih[:hidden_size] + b_hh[:hidden_size])
        f_t = torch.sigmoid(w_times_x[:, hidden_size:2*hidden_size] + w_times_prev_h[:, hidden_size:2*hidden_size]
                            + b_ih[hidden_size:2*hidden_size] + b_hh[hidden_size:2*hidden_size])
        g_t = torch.tanh(w_times_x[:, 2*hidden_size:3*hidden_size] + w_times_prev_h[:, 2*hidden_size:3*hidden_size]
                            + b_ih[2*hidden_size:3*hidden_size] + b_hh[2*hidden_size:3*hidden_size])
        o_t = torch.sigmoid(w_times_x[:, 3*hidden_size:4*hidden_size] + w_times_prev_h[:, 3*hidden_size:4*hidden_size]
                            + b_ih[3*hidden_size:4*hidden_size] + b_hh[3*hidden_size:4*hidden_size])

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t *torch.tanh(prev_c) #[bs, hidden_size]

        if w_hr is not None:
            #  w_hr维度:[proj_size, h_size]
            prev_h = torch.bmm(batch_w_hr, prev_h.unsqueeze(-1))      # [bs, proj_size, 1]
            prev_h = prev_h.squeeze(-1) # [bs, proj_size]

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)
  • 使用pytorch中LSTM的内置参数测试该函数
c_output, (c_h_n, c_c_n) = lstmp_forward(input, (h0, c0), layer_LSTMP.weight_ih_l0, layer_LSTMP.weight_hh_l0,
                                         layer_LSTMP.bias_ih_l0, layer_LSTMP.bias_hh_l0, layer_LSTMP.weight_hr_l0)
print("---手写lstmp_forward函数输出---")
print(c_output)

查看两者的输出结果完全相同
在这里插入图片描述

总结

在本次学习中,通过对LSTM运算过程的代码逐行实现,了解到LSTM模型,以及加深了自己对LSTM模型的理解与推导。在上次的学习中,我学习到RNN不能处理长依赖问题,而LSTM是一种特殊的RNN,比较适用于解决长依赖问题,而且LSTM与RNN一样是链式结构,但是结构不相同,有4个input,分别指外界存储到Memory里面的值和3个gate(Input Gate、Forget Gate和Output Gate)的信号,它们都是以比较简单的方式起作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/890374.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙--知乎评论

这里我们将采用组件化的思想进行开发 在开发中默认展示的是首页也就是 pages/Index.ets页面 这里存放的是所有页面的配置文件,类似与uniapp中的pages.json 如果我们此时要更改默认显示Zh

jmeter入门: 安装

前提&#xff1a; 安装jdk1.8&#xff0c; 并设置java_home 和path环境变量。 ​​​​​​1. download Apache JMeter - Download Apache JMeter 2. 解压jmeter包 3. 安装插件Install :: JMeter-Plugins.org 下载jar包&#xff0c;放到lib/ext目录 4. 打开jmeter &#xff0…

安装Node.js环境,安装vue工具

一、安装Node.js 去官方网站自行安装自己所需求的安装包 这是下载的官方网站 下载 | Node.js 中文网 给I accept the terms in the License Agreement打上勾然后点击Next 把安装包放到自己所知道的位置,后面一直点Next即可 等待它安装好 然后winr打开命令提示符cmd 二、安装…

解决报错:Invalid number of channels [PaErrorCode -9998]

继昨天重装了树莓派系统后&#xff0c;今天开始重新安装语音助手。在测试录音代码时遇到了报错“Invalid number of channels [PaErrorCode -9998]”&#xff0c;这是怎么回事&#xff1f; 有人说这是因为pyaudio没有安装成功造成的。于是&#xff0c;我pip3 install –upgrad…

难点:Linux 死机定位(进程虚拟地址空间耗尽)

死机定位(进程虚拟地址空间耗尽) 一、死机现象 内存富裕,但内存申请失败。 死机时打印: 怀疑是: 1、内存碎片原因导致。 2、进程虚拟地址空间耗尽导致。 3、进程资源限制导致。 二、内存碎片分析 1、理论知识:如何分析内存碎片化情况 使用 /proc/buddyinfo: /proc/…

数据结构-串

串的定义 串的操作 字符集编码 串的顺序存储 串的链式存储 模式匹配

完成Sentinel-Dashboard控制台数据的持久化-同步到Nacos

本次案例采用的是Sentinel1.8.8版本 一、Sentinel源码环境搭建 1、下载Sentinel源码工程 git clone https://github.com/alibaba/Sentinel.git 2、导入到idea 这里可以先运行DashboardApplication.java试一下是否运行成功&#xff0c;若成功&#xff0c;源码环境搭建完毕&a…

树莓派应用--AI项目实战篇来啦-11.OpenCV定位物体的实时位置

1. 介绍 本项目通过PCA9685舵机控制模块控制二自由度舵机云台固定在零点位置&#xff0c;然后通OpenCV检测到黄色小熊&#xff0c;找到中心位置并打印出中心位置的坐标&#xff0c;通过双色LED灯进行指示是否检测到目标&#xff0c;本项目为后面二维云台追踪物体和追踪人脸提供…

图论day56|广度优先搜索理论基础 、bfs与dfs的对比(思维导图)、 99.岛屿数量(卡码网)、100.岛屿的最大面积(卡码网)

图论day56|广度优先搜索理论基础 、bfs与dfs的对比&#xff08;思维导图&#xff09;、 99.岛屿数量&#xff08;卡码网&#xff09;、100.岛屿的最大面积&#xff08;卡码网&#xff09;&#xff09; 广度优先搜索理论基础bfs与dfs的对比&#xff08;思维导图&#xff09;&…

关于Linux下C++程序内存dump的分析和工具

前言 程序崩溃令人很崩溃&#xff0c;特别是让人找不到原因的崩溃&#xff0c;但是合适的工具可以帮助人很快的定位到问题&#xff0c;在AI基础能力ASR服务开发时&#xff0c;找到了一种比较实用和简单的内存崩溃的dump分析工具breakpad&#xff0c; 可以帮助在Linux下C开发程…

C语言初阶-数据类型和变量【下】

紧接上期------------------------->>>C语言初阶-数据类型和变量【上】 全局变量和局部变量在内存中存储在哪⾥呢&#xff1f; ⼀般我们在学习C/C语⾔的时候&#xff0c;我们会关注内存中的三个区域&#xff1a; 栈区 、 堆区 、 静态区 。 内存的分配情况 局部变量是…

Java->排序

目录 一、排序 1.概念 2.常见的排序算法 二、常见排序算法的实现 1.插入排序 1.1直接插入排序 1.2希尔排序(缩小增量法) 1.3直接插入排序和希尔排序的耗时比较 2.选择排序 2.1直接选择排序 2.2堆排序 2.3直接选择排序与堆排序的耗时比较 3.交换排序 3.1冒泡排序…

肺腺癌预后新指标:全切片图像中三级淋巴结构密度的自动化量化|文献精析·24-10-09

小罗碎碎念 本期这篇文章&#xff0c;我去年分享过一次。当时发表在知乎上&#xff0c;没有标记参考文献&#xff0c;配图的清晰度也不够&#xff0c;并且分析的还不透彻&#xff0c;所以趁着国庆假期重新分析一下。 这篇文章的标题为《Computerized tertiary lymphoid structu…

【实战】Nginx+Lua脚本+Redis 实现自动封禁访问频率过高IP

大家好&#xff0c;我是冰河~~ 自己搭建的网站刚上线&#xff0c;短信接口就被一直攻击&#xff0c;并且攻击者不停变换IP&#xff0c;导致阿里云短信平台上的短信被恶意刷取了几千条&#xff0c;加上最近工作比较忙&#xff0c;就直接在OpenResty上对短信接口做了一些限制&am…

《Linux运维总结:基于ARM64+X86_64架构CPU使用docker-compose一键离线部署mongodb 7.0.14容器版分片集群》

总结&#xff1a;整理不易&#xff0c;如果对你有帮助&#xff0c;可否点赞关注一下&#xff1f; 更多详细内容请参考&#xff1a;《Linux运维篇&#xff1a;Linux系统运维指南》 一、部署背景 由于业务系统的特殊性&#xff0c;我们需要面向不通的客户安装我们的业务系统&…

C++入门基础知识110—【关于C++ if...else 语句】

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///C爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于C if...else 语句的相关内容&#xff01…

数据结构-5.2.树的性质

一.树的常考性质&#xff1a; 性质1&#xff1a;结点数 总度数 1(结点的度&#xff1a;结点分支的数量) 一个分支中&#xff0c;如父结点B&#xff0c;两个子结点为E和F&#xff0c;结点B的度的值为2&#xff0c;等于子结点数量&#xff0c;加上这一个父结点(父结点只能有一…

部署私有仓库以及docker web ui应用

官方地址&#xff1a;https://hub.docker.com/_/registry/tags 一、拉取registry私有仓库镜像 docker pull registry:latest 二、运⾏容器 docker run -itd -v /home/dockerdata/registry:/var/lib/registry --name "pri_registry1" --restartalways -p 5000:5000 …

数据结构-5.5.二叉树的存储结构

一.二叉树的顺序存储&#xff1a; a.完全二叉树&#xff1a; 1.顺序存储中利用了静态数组&#xff0c;空间大小有限&#xff1a; 2.基本操作&#xff1a; (i是结点编号) 1.上述图片中i所在的层次后面的公式应该把n换成i(图片里写错了)&#xff1b; 2.上述图片判断i是否有左…

ClickHouse的原理及使用,

1、前言 一款MPP查询分析型数据库——ClickHouse。它是一个开源的&#xff0c;面向列的分析数据库&#xff0c;由Yandex为OLAP和大数据用例创建。ClickHouse对实时查询处理的支持使其适用于需要亚秒级分析结果的应用程序。ClickHouse的查询语言是SQL的一种方言&#xff0c;它支…