双向长短期记忆(Bi-LSTM)神经网络介绍

      长短期记忆(Long Short-Term Memory, LSTM)神经网络

      1.是Hochreiter和Schmidhuber设计的循环神经网络(Recurrent Neural Network, RNN)的改进版本。LSTM模型借鉴了人类大脑的选择性输入和选择性遗忘机制,获取序列中的关键信息,遗忘和当前预测任务无关的信息。

      2.在循环神经网络的基础上,在隐藏层各神经单元中增加记忆单元(memory cell,一种可以长时间保存信息的容器),从而使时间序列上的记忆信息可控。每次在隐藏层各单元间传递时通过三个可控门(遗忘门、输入门、输出门),可以控制之前信息和当前信息的记忆和遗忘程度,从而使LSTM网络具备了长期记忆功能。

      3.在RNN中,有一个很大局限,就是梯度消失和梯度爆炸问题。也就是RNN只能对较短距离的信息进行记忆,随着时间的间隔增大,RNN学习起来就会变的非常的困难。LSTM神经网络的出现很好的解决了这个问题,相对于传统的RNN,LSTM神经网络在其中添加了一个能够记忆长时间信息的单元。

      4.LSTM神经网络一般由输入层,隐藏层和输出层构成LSTM神经网络的核心就是其中的具有记忆功能的隐藏层。如下图所示是具有记忆功能的LSTM单元图:它包括输入门、输出门、遗忘门以及记忆单元块。图中xt代表当前时刻的输入数据,Ct-1代表上一时刻的单元状态(cell state),ht-1代表上一时刻的输出(或上一时刻的隐藏状态(hidden state)),Ct代表当前记忆单元状态,ht代表t时刻的输出值(或当前的隐藏状态)。ft代表遗忘门,it代表输入门,ot代表输出门。其中的输入门以及输出门为控制门,输入门决定给记忆单元传送的信息量,用来控制当前的输入值xt有多少数据被保留在Ct中,从而实现对单元状态Ct的更新;输出门决定将记忆单元中的多少信息传送给当前的输出。遗忘门对记忆单元进行控制,用来决定记忆单元的记忆和遗忘,决定的是上一时刻的记忆单元有多少数据将传递到现在。

      5.LSTM的核心概念是单元状态及其各种门。单元状态充当传输高速公路,将相关信息传输到序列链的下游。你可以将其视为网络的"记忆(memory)"。理论上,单元状态可以在整个序列处理过程中携带相关信息。因此,即使是来自较早时间步(time step)的信息也可以进入后续时间步,从而减少短期记忆(short-term memory)的影响。随着单元状态的旅程,信息会通过门添加到单元状态或从单元状态中删除。门是不同的神经网络,它们决定哪些信息允许进入单元状态。门可以了解在训练期间哪些信息是相关的,需要保留或遗忘。有三个不同的门来调节LSTM单元中的信息流。由三个门控制记忆单元。这使得LSTM网络能够在信息流经网络时有选择地保留或丢弃信息,从而使它们能够学习长期依赖关系。

      (1).输入门(it):控制向记忆单元添加哪些信息。为了更新单元状态,我们有输入门。首先,我们将之前的隐藏状态和当前输入传递到sigmoid函数中。通过将值转换为0到1之间的值,该函数决定更新哪些值。0表示不重要,1表示重要。你还可以将隐藏状态和当前输入传递到tanh函数中,以将值压缩到-1和1之间,以帮助调节网络。然后将tanh输出与sigmoid输出相乘。sigmoid输出将决定哪些信息对于保留在tanh输出中很重要。

      (2).遗忘门(ft):控制从记忆单元移除哪些信息。该门决定应该丢弃或保留哪些信息,来自前一个隐藏状态的信息和来自当前输入的信息通过sigmoid函数传递。值介于0和1之间。越接近0表示丢弃,越接近1表示保留。

      (3).输出门(ot):控制从记忆单元输出哪些信息。输出门决定下一个隐藏状态应该是什么。隐藏状态包含有关先前输入的信息。隐藏状态也用于预测。首先,我们将先前的隐藏状态和当前输入传递给sigmoid函数。然后我们将新修改的单元状态传递给tanh函数。我们将tanh输出与sigmoid输出相乘,以决定隐藏状态应该携带什么信息。输出是隐藏状态。然后,新的单元状态和新的隐藏状态被延续到下一个时间步。

      (4).隐藏状态(ht):充当网络的短期记忆,保存着网络之前见过的先前数据的信息。隐藏状态根据输入、先前的隐藏状态和记忆单元的当前状态进行更新。

      (5).单元状态(ct):首先,单元状态逐点(pointwise)乘以遗忘向量(forget vector)。如果乘以接近0的值,则有可能丢弃单元状态中的值。然后我们从输入门获取输出并进行逐点加法,将单元状态更新为神经网络认为相关的新值。这给了我们新的单元状态。

      遗忘门决定了哪些信息与之前的时间步相关,哪些信息需要保留。输入门决定了哪些信息与当前时间步相关,哪些信息需要添加。输出门决定了下一个隐藏状态应该是什么

      6.LSTM架构中的网络可以堆叠以创建深度架构,从而能够学习序列数据中更复杂的模式和层次结构。LSTM架构具有链式结构。

      7.LSTM使用单元状态来存储有关过去输入的信息。此单元状态在网络的每个步骤中更新,网络使用它来对当前输入进行预测。单元状态使用一系列门进行更新,这些门控制允许多少信息流入和流出单元。

      8.LSTM的控制流程与RNN类似。它在信息前向传播的过程中处理数据。不同之处在于LSTM单元(cell)内的操作。

      双向长短期记忆(Bi-directional Long Short-Term Memory, Bi-LSTM)神经网络

      1.Bi-LSTM是LSTM的扩展,涉及两个并行运行的LSTM,两个LSTM网络组成,前向LSTM和后向LSTM,一个网络处理前向输入序列,另一个网络处理后向输入序列,每个隐藏层的输出由两个LSTM组合而成。

      2.Bi-LSTM神经网络在训练时能同时对当前训练t时刻之前的历史数据和之后的未来数据进行充分利用。

      3.Bi-LSTM基本工作原理:通过前向LSTM和后向LSTM得到两个时间序列相反的隐藏层状态,然后将其连接得到同一个输出,其他步骤与LSTM训练过程类似。前向LSTM 和后向LSTM可以分别获取当前输入序列的前向信息和后向信息。

     注:以上整理的内容主要来自:

      1. https://towardsdatascience.com

      2. 硕论,《基于LSTM的人体连续动作识别》

      PyTorch中torch.nn.LSTM使用说明

     1.声明如下:

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)

      2.公式如下:

      3.参数说明:

      (1).input_size:输入x的特征数量。

      (2).hidden_size:隐藏状态h的特征数量。

      (3).num_layers:循环层的数量。例如,设置num_layers=2意味着将两个LSTM堆叠在一起以形成堆叠的LSTM,第二个LSTM接收第一个LSTM的输出并计算最终结果。默认值为1。

      (4).bias:如果为False,则该层不使用偏置(bias)。默认为True。

      (5).batch_first:如果为True,则输入和输出张量将以(batch, seq, feature)而不是(seq, batch, feature) 的形式提供。这不适用于隐藏状态或单元状态。默认为False。

      (6).dropout:如果非零,则在除最后一层之外的每个LSTM层的输出上引入Dropout层,dropout概率等于dropout。默认值为0.0。

      (7).bidirectional:如果为True,则变为双向LSTM。默认为False。

      (8).proj_size:如果大于0,将使用具有投影的LSTM(LSTM with projections will be used)。默认值为0。

      4.对于inputs和outputs: 假设batch_first=False, batch input,proj_size=0

      (1).inputs: input, (h_0, c_0)

      1).input:(L,N,Hin)

      2).h_0:(D∗num_layers,N,Hout)

      3).c_0:(D∗num_layers,N,Hcell)

      (2).outputs: output, (h_n, c_n)

      1).output: (L,N,D∗Hout)

      2).h_n:(D*num_layers,N,Hout)

      3).c_n:(D*num_layers,N,Hcell)

      其中:

      N = batch size

      L = sequence length

      D = 2 if bidirectional=True otherwise 1

      Hin = input_size

      Hcell = hidden_size

      Hout = hidden_size

      将Bi-LSTM用于分类,示例代码如下:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import colorama
import argparse

def parse_args():
	parser = argparse.ArgumentParser(description="BiLSTM classify")
	parser.add_argument("--epochs", type=int, default=10, help="number of training")
	parser.add_argument("--lr", type=float, default=0.003, help="learning rate")
	parser.add_argument("--batch_size", type=int, default=64, help="batch size during training")
	parser.add_argument("--hidden_size", type=int, default=128, help="hidden state size")
	parser.add_argument("--num_layers", type=int, default=2, help="number of recurrent layers")

	args = parser.parse_args()
	return args

def load_data(batch_size):
	train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
	test_dataset = torchvision.datasets.MNIST(root='../../data/', train=False, transform=transforms.ToTensor())

	train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
	test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

	return train_loader, test_loader

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
	def __init__(self, input_size, hidden_size, num_layers, num_classes, device):
		super(BiRNN, self).__init__()
		self.hidden_size = hidden_size
		self.num_layers = num_layers
		self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
		self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
		self.device = device

	def forward(self, x):
		# Set initial states
		h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(self.device) # 2 for bidirection, x.size(0)=batch_size
		c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(self.device)

		# Forward propagate LSTM
		out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
		# Decode the hidden state of the last time step
		out = self.fc(out[:, -1, :]) # out: tensor of shape (batch_size, num_classes)

		return out

def train(epochs, lr, batch_size, hidden_size, num_layers, device, input_size, sequence_length, num_classes):
	train_loader, test_loader = load_data(batch_size)

	model = BiRNN(input_size, hidden_size, num_layers, num_classes, device).to(device)

	# Loss and optimizer
	criterion = nn.CrossEntropyLoss()
	optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
	# Train the model
	total_step = len(train_loader)
	for epoch in range(epochs):
		model.train()
		for i, (images, labels) in enumerate(train_loader):
			images = images.reshape(-1, sequence_length, input_size).to(device)
			labels = labels.to(device)

			# Forward pass
			outputs = model(images)
			loss = criterion(outputs, labels)

			# Backward and optimize
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()

			if (i+1) % 100 == 0:
				print ("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}".format(epoch+1, epochs, i+1, total_step, loss.item()))

		# Test the model
		model.eval()
		with torch.no_grad():
			correct = 0
			total = 0
			for images, labels in test_loader:
				images = images.reshape(-1, sequence_length, input_size).to(device)
				labels = labels.to(device)
				outputs = model(images)
				_, predicted = torch.max(outputs.data, 1)
				total += labels.size(0)
				correct += (predicted == labels).sum().item()

			print("Test Accuracy of the model on the 10000 test images: {} %".format(100 * correct / total))

	# Save the model checkpoint
	torch.save(model.state_dict(), "model.pth")

if __name__ == "__main__":
	# reference: https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/bidirectional_recurrent_neural_network/main.py
	colorama.init(autoreset=True)
	args = parse_args()

	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	sequence_length = 28
	input_size = 28
	num_classes = 10

	train(args.epochs, args.lr, args.batch_size, args.hidden_size, args.num_layers, device, input_size, sequence_length, num_classes)

	print(colorama.Fore.GREEN + "====== execution completed ======")

      执行结果如下图所示:

      GitHub:https://github.com/fengbingchun/NN_Test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/928018.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

.NET 9 中 LINQ 新增功能实现过程

本文介绍了.NET 9中LINQ新增功能,包括CountBy、AggregateBy和Index方法,并提供了相关代码示例和输出结果,感兴趣的朋友跟随我一起看看吧 LINQ 介绍 语言集成查询 (LINQ) 是一系列直接将查询功能集成到 C# 语言的技术统称。 数据查询历来都表示为简单的…

解决PowerPoint的流程图图标中输入文字位置偏下的问题

解决PowerPoint的流程图图标中输入文字位置偏下的问题 背景 在PowerPoint中,插入流程图形状,并在其内部输入中文字符,是很常规的操作。然而,有时输入文本发现文本整体偏下,靠近流程图下侧。 症状 文字位置偏下的效…

C++基础:list的基本使用

文章目录 1.基本构造和插入删除基本构造和尾插数据迭代器的分类内置排序sort任意位置插入删除 2.链表的合并,去重和剪切链表的合并链表去重链表的剪切 list的本质就是带头双向循环列表 1.基本构造和插入删除 基本构造和尾插数据 与之前vector的方法相同直接调用即可 迭代器的分…

SpringBoot中实现EasyExcel实现动态表头导入(完整版)

前言 最近在写项目的时候有一个需求,就是实现动态表头的导入,那时候我自己也不知道动态表头导入是什么,查询了大量的网站和资料,终于了解了动态表头导入是什么。 一、准备工作 确保项目中引入了处理 Excel 文件的相关库&#xff…

亚马逊云(AWS)使用root用户登录

最近在AWS新开了服务器(EC2),用于学习,遇到一个问题就是默认是用ec2-user用户登录,也需要密钥对。 既然是学习用的服务器,还是想直接用root登录,下面开始修改: 操作系统是&#xff1…

基于Java Springboot武汉市公交路线查询APP且微信小程序

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术:Html、Css、Js、Vue、Element-ui 数据库:MySQL 后端技术:Java、Spring Boot、MyBatis 三、运行环境 开发工具:IDEA/eclipse 微信…

【C++】数组

1.概述 所谓数组,就是一个集合,该集合里面存放了相同类型的数据元素。 数组特点: (1)数组中的每个数据元素都是相同的数据类型。 (2)数组是有连续的内存空间组成的。 2、一维数组 2.1维数组定…

WPF中的VisualState(视觉状态)

以前在设置控件样式或自定义控件时&#xff0c;都是使用触发器来进行样式更改。触发器可以在属性值发生更改时启动操作。 像这样&#xff1a; <Style TargetType"ListBoxItem"><Setter Property"Opacity" Value"0.5" /><Setter …

day04【入门】MySQL学习(1)

目前的学习进度&#xff0c;如上图所示。从晚上开始学习MySQL数据库啦。 目录 1、数据库简介 2、数据集连接及准备工作 3、sql 语言中的注释 4、MySQL中常用数据类型 5、数据库中元素 6、创建表 7、insert插入记录 8、select查询 9、update修改数据 10、delete删除、t…

redis核心命令全局命令 + redis 常见的数据结构 + redis单线程模型

文章目录 一. 核心命令1. set2. get 二. 全局命令1. keys2. exists3. del4. expire5. ttl6. type 三. redis 常见的数据结构及内部编码四. redis单线程模型 一. 核心命令 1. set set key value key 和 value 都是string类型的 对于key value, 不需要加上引号, 就是表示字符串…

Dromara WarmFlow工作流动态指定办理人

Dromara WarmFlow工作流动态指定办理人 背景&#xff1a; 审批任务的办理人&#xff0c;通常是在流程设计器中预先设定好办理人&#xff0c;那如果想要在办理过程中指定办理人呢&#xff1f; 那不得不提一下本次的主角&#xff0c;来自Dromara组织的WarmFlow工作流&#xff0…

理解Parquet文件和Arrow格式:从Hugging Face数据集的角度出发

parquet发音&#xff1a;美 [pɑrˈkeɪ] 镶木地板&#xff1b;拼花木地板 理解Parquet文件和Arrow格式&#xff1a;从Hugging Face数据集的角度出发 引言 在机器学习和大数据处理中&#xff0c;数据的存储和传输格式对于性能至关重要。两种广泛使用的格式是 Parquet 和 Arr…

Kylin Server V10 下 Kafka 集群部署

一、ZooKeeper 集群部署 1、主机规划 主机名 IP 地址 myid 10.8.3.35 1 10.8.3.36 2 10.8.3.37 3 2、拓扑结构 3、部署 (1) 下载Zookeeper [root@localhost ~]# cd /usr/local [root@localhost local]# wget https://www.apache.org/dyn/closer.lua/zookeeper/zookeeper-…

【MySql】navicat连接报2013错误

navicat连接mysql报2013错误 报错信息1、检验Mysql数据库是否安装成功2、对Mysql的配置文件进行修改配置2.1、找到配置文件2.2、Linux下修改配置文本 3、连接进入mysql服务4、在mysql下执行授权命令 报错信息 Navicat连接mysql报2013错误 2013-Lost connection to MYSQL serve…

机器学习——决策树模型

决策树是如何工作的&#xff1f; 假设你在经营一家猫收养中心&#xff0c;并提供了一些功能&#xff0c;你想训练一个分类器来快速告诉你&#xff0c;动物到底是不是猫&#xff0c;这里有10个训练例子&#xff0c;并与这10个例子中的每一个相关联&#xff0c;我们将有关于动物…

65页PDF | 企业IT信息化战略规划(限免下载)

一、前言 这份报告是企业IT信息化战略规划&#xff0c;报告详细阐述了企业在面对新兴技术成熟和行业竞争加剧的背景下&#xff0c;如何通过三个阶段的IT战略规划&#xff08;IT 1.0基础建设、IT 2.0运营效率、IT 3.0持续发展&#xff09;&#xff0c;系统地构建IT管理架构、应…

【C++】—— 从零开始封装 Map 与 Set:实现与优化

人生的态度是&#xff0c;抱最大的希望&#xff0c;尽最大的努力&#xff0c;做最坏的打算。 —— 柏拉图 《理想国》 目录 1、理论基石——深度剖析 BSTree、AVLTree 与 RBTree 的概念区别 2、迭代器机制——RBTree 迭代器的架构与工程实现 3、高级容器设计——Map 与 Set…

番外:HTTP、WebSocket 和 gRPC 协议详解

HTTP、WebSocket 和 gRPC 协议详解 在现代网络编程中&#xff0c;HTTP、WebSocket 和 gRPC 协议是三种常用的通信协议。它们各自有着不同的特点和适用场景。本文将从功能、优缺点、使用场景、工作原理以及发明背景等多个方面深入探讨这三种协议。 HTTP协议 功能 HTTP&#…

单元测试报websocket bean创建失败

在单元测试的基础类上添加注解 使用tomcat容器启动即可 如下图

Samtec 大数据应用科普 | 用于HPC和超级计算的连接器

【摘要/前言】 我们的现代工业世界依赖于原材料。无论是石油、钢铁还是棉花&#xff0c;原材料对工业化世界的发展都具有巨大的重要性。尽管这些有形产品仍然一如既往地重要&#xff0c;但现代社会为我们带来了一种新的原材料&#xff0c;那就是数据。 【数据是未来的原材料】…