用通俗易懂的方式讲解:LSTM原理及生成藏头诗(Python)

一、基础介绍

1.1 神经网络模型

常见的神经网络模型结构有前馈神经网络(DNN)、RNN(常用于文本 / 时间系列任务)、CNN(常用于图像任务)等等。

前馈神经网络是神经网络模型中最为常见的,信息从输入层开始输入,每层的神经元接收前一级输入,并输出到下一级,直至输出层。整个网络信息输入传输中无反馈(循环)。即任何层的输出都不会影响同级层,可用一个有向无环图表示。图片

1.2 RNN 介绍

循环神经网络(RNN)是基于序列数据(如语言、语音、时间序列)的递归性质而设计的,是一种反馈类型的神经网络,它专门用于处理序列数据,如逐字生成文本或预测时间序列数据(例如股票价格、诗歌生成)。图片

RNN和全连接神经网络的本质差异在于“输入是带有反馈信息的”,RNN除了接受每一步的输入x(t) ,同时还有输入上一步的历史反馈信息——隐藏状态h (t-1) ,也就是当前时刻的隐藏状态h(t) 或决策输出O(t) 由当前时刻的输入 x(t) 和上一时刻的隐藏状态h (t-1) 共同决定。从某种程度,RNN和大脑的决策很像,大脑接受当前时刻感官到的信息(外部的x(t) )和之前的想法(内部的h (t-1) )的输入一起决策。

图片

RNN的结构原理可以简要概述为两个公式

RNN的隐藏状态为:h(t) = f( U * x(t) + W * h(t-1) + b1), f为激活函数,常用tanh、relu;

RNN的输出为:o(t) = g( V * h(t) + b2),g为激活函数,当用于分类任务,一般用softmax;

1.3 从RNN到LSTM

但是在实际中,RNN在长序列数据处理中,容易导致梯度爆炸或者梯度消失,也就是长期依赖(long-term dependencies)问题,其根本原因就是模型“记忆”的序列信息太长了,都会一股脑地记忆和学习,时间一长,就容易忘掉更早的信息(梯度消失)或者崩溃(梯度爆炸)。

梯度消失:历史时间步的信息距离当前时间步越长,反馈的梯度信号就会越弱(甚至为0)的现象,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。

改善措施:可以使用 ReLU 激活函数;门控RNN 如GRU、LSTM 以改善梯度消失。

梯度爆炸:网络层之间的梯度(值大于 1)重复相乘导致的指数级增长会产生梯度爆炸,导致模型无法有效学习。

改善措施:可以使用 梯度截断;引导信息流的正则化;ReLU 激活函数;门控RNN 如GRU、LSTM(和普通 RNN 相比多经过了很多次导数都小于 1激活函数,因此 LSTM 发生梯度爆炸的频率要低得多)以改善梯度爆炸。

所以,如果我们能让 RNN 在接受上一时刻的状态和当前时刻的输入时,有选择地记忆和遗忘一部分内容(或者说信息),问题就可以解决了。比如上上句话提及”我去考试了“,然后后面提及”我考试通过了“,那么在此之前说的”我去考试了“的内容就没那么重要,选择性地遗忘就好了。这也就是长短期记忆网络(Long Short-Term Memory, LSTM)的基本思想。

二、LSTM原理

LSTM是种特殊RNN网络,在RNN的基础上引入了“门控”的选择性机制,分别是遗忘门、输入门和输出门,从而有选择性地保留或删除信息,以能够较好地学习长期依赖关系。如下图RNN(上) 对比 LSTM(下):

图片

2.1 LSTM的核心

在RNN基础上引入门控后的LSTM,结构看起来好复杂!但其实LSTM作为一种反馈神经网络,核心还是历史的隐藏状态信息的反馈,也就是下图的Ct:图片对标RNN的ht隐藏状态的更新,LSTM的Ct只是多个些“门控”删除或添加信息到状态信息。由下面依次介绍LSTM的“门控”:遗忘门,输入门,输出门的功能,LSTM的原理也就好理解了。

2.2 遗忘门

LSTM 的第一步是通过"遗忘门"从上个时间点的状态Ct-1中丢弃哪些信息。

具体来说,输入Ct-1,会先根据上一个时间点的输出ht-1和当前时间点的输入xt,并通过sigmoid激活函数的输出结果ft来确定要让Ct-1,来忘记多少,sigmoid后等于1表示要保存多一些Ct-1的比重,等于0表示完全忘记之前的Ct-1。图片

2.3 输入门

下一步是通过输入门,决定我们将在状态中存储哪些新信息。

我们根据上一个时间点的输出ht-1和当前时间点的输入xt 生成两部分信息i t 及C~t,通过sigmoid输出i t,用tanh输出C~t。之后通过把i t 及C~t两个部分相乘,共同决定在状态中存储哪些新信息。图片

在输入门 + 遗忘门控制下,当前时间点状态信息Ct为:

图片

2.4 输出门

最后,我们根据上一个时间点的输出ht-1和当前时间点的输入xt 通过sigmid 输出Ot,再根据Ot 与 tanh控制的当前时间点状态信息Ct 相乘作为最终的输出。图片

综上,一张图可以说清LSTM原理:图片

三、LSTM简单写诗

本节项目利用深层LSTM模型,学习大小为10M的诗歌数据集,自动可以生成诗歌。图片

如下代码构建LSTM模型。如需完整代码,文末获取

图片

model = tf.keras.Sequential([
    # 不定长度的输入
    tf.keras.layers.Input((None,)),
    # 词嵌入层
    tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
    # 第一个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 第二个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 对每一个时间点的输出都做softmax,预测下一个词的概率
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])

# 查看模型结构
model.summary()
# 配置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

模型训练,考虑训练时长,就简单训练2个epoch。图片

class Evaluate(tf.keras.callbacks.Callback):
    """
    训练过程评估,在每个epoch训练完成后,保留最优权重,并随机生成SHOW_NUM首古诗展示
    """

    def __init__(self):
        super().__init__()
        # 给loss赋一个较大的初始值
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # 在每个epoch训练完成后调用
        # 如果当前loss更低,就保存当前模型参数
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save(BEST_MODEL_PATH)
        # 随机生成几首古体诗测试,查看训练效果
        print("cun'h")
        for i in range(SHOW_NUM):
            print(generate_acrostic(tokenizer, model, head="春花秋月"))

# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=TRAIN_EPOCHS,
                    callbacks=[Evaluate()])

加载简单训练的LSTM模型,输入关键字(如:算法进阶)后,自动生成藏头诗。可以看出诗句粗略看上去挺优雅,但实际上经不起推敲。后面增加训练的epoch及数据集应该可以更好些。

# 加载训练好的模型
model = tf.keras.models.load_model(BEST_MODEL_PATH)

keywords = input('输入关键字:\n')


# 生成藏头诗
for i in range(SHOW_NUM):
    print(generate_acrostic(tokenizer, model, head=keywords),'\n')

图片

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了NLP面试与技术交流群, 想要进交流群、需要本文源码、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、添加微信号:mlc2060,备注:技术交流
方式②、微信搜索公众号:机器学习社区,后台回复:技术交流

资料
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/297458.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

软件测试|全面解析Docker Start/Stop/Restart命令:管理容器生命周期的必备工具

简介 Docker是一种流行的容器化平台&#xff0c;用于构建、分发和运行应用程序。在使用Docker时&#xff0c;经常需要管理容器的生命周期&#xff0c;包括启动、停止和重启容器。本文将详细介绍Docker中的docker start、docker stop和docker restart命令&#xff0c;帮助您全面…

Hadoop集群三节点搭建(二)

一、克隆三台主机&#xff08;hadoop102 hadoop103 hadoop104&#xff09; 以master为样板机克隆三台出来&#xff0c;克隆前先把master关机 按照上面的步骤克隆其他两个就可以了&#xff0c;记得修改ip和hostname 二、编写集群同步脚本 在/home/attest/ 创建bin目录&…

Linux第9步_通过终端查看U盘文件

学习完“USB设置”后&#xff0c;我们学习通过终端来查看U盘文件。前面讲解过使用鼠标打开U盘&#xff0c;但是在实际使用中&#xff0c;更多的还是采用命令来实现对U盘的操作。 1、在桌面&#xff0c;右击鼠标&#xff0c;弹出下面的界面: 2、点击上图中的“打开终端”&#…

2024年了,是该学学Three.js了

前言 &#x1f4eb; 大家好&#xff0c;我是南木元元&#xff0c;热衷分享有趣实用的文章&#xff0c;希望大家多多支持&#xff0c;一起进步&#xff01; &#x1f345; 个人主页&#xff1a;南木元元 目录 Three.js介绍 Three.js应用场景 搭建开发环境 初始化项目 创建文…

1.2作业

温湿度数据通过中断处理显示到数码管中 main.c #include "spi.h"#include"si7006.h"int main(){int i0,j0,m0,n0;int num[10] {0xFC,0x60,0xDA,0xF2,0x66,0xB6,0x3E,0xE0,0xFE,0xF6};SPI_init();unsigned short hum;short tem;//进行si7006的初始化si700…

如何让自己的写的程序在阿里云一直运行

购买了阿里云服务器后&#xff0c;每次要用自己写在阿里云的服务器程序都要连接到云端 然后./运行该程序&#xff0c;而且每次一断开终端&#xff0c;该服务器就会自动停止&#xff0c;这样使用相当麻烦。那怎样才能让我们的服务器一直在云端后台运行&#xff0c;即便退出终端…

Linux第19步_安装“Ubutun交叉编译工具链”

由于Ubuntu系统使用的GCC编译器&#xff0c;编译结果是X86文件&#xff0c;只能在X86上运行&#xff0c;不能在ARM上直接运行。因此&#xff0c;还要安装一个“Ubutun交叉编译工具链”&#xff0c;才可以在ARM上运行。 arm-none-linux-gnueabi-gcc是 Codesourcery 公司&#x…

【力扣每日一题】力扣2478从链表中移除节点

题目来源 2478.从链表中移除节点 题目描述 给你一个链表的头节点 head 。 移除每个右侧有一个更大数值的节点。 返回修改后链表的头节点 head 。 示例 示例1&#xff1a; 输入&#xff1a;head [5,2,13,3,8] 输出&#xff1a;[13,8] 解释&#xff1a;需要移除的节点是 5 …

Java并发集合详解

第1章&#xff1a;引言 大家好&#xff0c;我是小黑&#xff0c;在这篇博客中&#xff0c;咱们将一起深入探索Java中的并发集合。多线程编程是一个不可或缺的部分&#xff0c;它能让程序运行得更快&#xff0c;处理更多的任务。但同时&#xff0c;多线程也带来了一些挑战&…

unity编辑器Scene界面输出位置及路径

工程Asset下新建Editor文件夹&#xff1b; Editor文件夹下新建脚本LogPosition using System.Collections; using System.Collections.Generic; using UnityEditor; using UnityEngine; public class LogPosition : EditorWindow {//最终输出的数据.static string logtext;//增…

大学生搜题软件,未来可期吗?

作为一家专注于软件开发的公司《智创有术》&#xff0c;我们致力于为客户提供创新、高效和可靠的解决方案。通过多年的经验和专业知识&#xff0c;我们已经在行业内建立了良好的声誉&#xff0c;并赢得了客户的信任和支持。 支持各种源码&#xff0c;网站搭建&#xff0c;APP&a…

32.virtual reality system concepts illustrated using OSVR

32.1 Common Space This section describes the spaces needed to support viewing and interacting with the virtual world. 本节介绍支持查看虚拟世界并与之交互所需的空间。 The spaces required for supporting viewing and interacting with a virtual world can vary …

记一次服务器被入侵的排查过程

起因 阿里云安全中心报告了告警信息&#xff0c;同时手机短信、邮件、电话也接收到了来自阿里云的风险通知&#xff0c;感觉这方面阿里云还是不错。 排查及解决过程 这条wget指令究竟是怎么被运行的 我无法定位到攻击人员是通过什么样的方式让我的java程序执行了wget这条指…

转后端一年半双非本科Java无实习进大厂,给双非朋友经验分享

背景介绍 B站有详细视频&#xff0c;同名搜索即可。 今天文章想分享的是我踩过的坑以及那些做的是值得大家参考。 有需要就加V&#xff1a; zhazhagao_ 进了快手(如果你觉得不是大厂那就不是!)&#xff1a; 真双非本科: 安徽某双非无实习: 因为编程语言问题,去过之后发现不喜欢…

Java中请求生成唯一追溯TraceId

Java中请求生成唯一追溯TraceId 一&#xff1a;背景 因为是微服务架构,平常日志太多,看日志不太好查,所以想要从一整个链路当中获取一个唯一标识,比较好定位问题&#xff0c; 原理就是从gateway网关将标识传递到下游,下游服务拿到这个标识,响应结束后将traceId反向写入响应体…

C# Onnx Chinese CLIP 通过一句话从图库中搜出来符合要求的图片

目录 效果 生成图片特征 查找踢足球的人 测试图片 模型信息 image_model.onnx text_model.onnx 项目 代码 Form1.cs Clip.cs 下载 C# Onnx Chinese CLIP 通过一句话从图库中搜出来符合要求的图片 效果 生成图片特征 查找踢足球的人 测试图片 模型信息 image_mod…

RFM会员价值度模型

模型基本原理 会员价值度用来评估用户的价值情况&#xff0c;是区分会员价值的重要模型和参考依据&#xff0c;也是衡量不同营销效果的关键指标。 价值度模型一般基于交易行为产生&#xff0c;衡量的是有实体转化价值的行为。常用的价值度模型是RFM RFM模型是根据会员 最近…

如何在 Windows 上从电脑硬盘恢复照片

如今&#xff0c;随着相机设备的普及&#xff0c;您可以轻松地一次拍摄一堆照片&#xff0c;将它们传输到硬盘上&#xff0c;然后再拍摄更多照片。但是&#xff0c;如果您的所有照片意外丢失在驱动器中怎么办&#xff1f;你能恢复它们吗&#xff1f; 在本指南中&#xff0c;我…

配置文件的创建和部署

配置描述文件&#xff08;Configuration Profiles&#xff09;的格式为xml&#xff0c;其提供了一个非常容易的方式去给电脑、移动设备或用户定义一些设置或限制&#xff0c;你可以使用Jamf Pro去创建这样的配置文件。 &#xff08;配置文件的负载&#xff09; 有关配置文件的配…

Spring——Spring AOP1(代理模式Proxy)

代理&#xff08;Proxy&#xff09;模式 1.创建工程 2.代理&#xff08;Proxy&#xff09;模式介绍 作用&#xff1a;通过代理可以控制访问某个对象的方法&#xff0c;在调用这个方法前做前置处理&#xff0c;调用这个方法后做后置处理。&#xff08;即&#xff1a; AOP的微观…