66、基于长短期记忆 (LSTM) 网络对序列数据进行分类

1、基于长短期记忆 (LSTM) 网络对序列数据进行分类的原理及流程

基于长短期记忆(LSTM)网络对序列数据进行分类是一种常见的深度学习任务,适用于处理具有时间或序列关系的数据。下面是在Matlab中使用LSTM网络对序列数据进行分类的基本原理和流程:

  1. 准备数据

    • 确保数据集中包含带有标签的序列数据,例如时间序列数据、文本数据等。
    • 将数据进行预处理和归一化,以便输入到LSTM网络中。
  2. 构建LSTM网络

    • 在Matlab中,可以使用内置函数 lstmLayer 来构建LSTM层。
    • 指定输入数据维度、LSTM单元数量、输出层大小等参数。
    • 通过 layers = [sequenceInputLayer(inputSize), lstmLayer(numHiddenUnits), fullyConnectedLayer(numClasses), classificationLayer()] 构建完整的LSTM分类网络。
  3. 定义训练选项

    • 设置训练选项,例如学习率、最大迭代次数、小批量大小等。
    • 使用 trainingOptions 函数来定义训练选项。
  4. 训练网络

    • 使用 trainNetwork 函数来训练构建好的LSTM网络。
    • 输入训练数据和标签,并使用定义好的训练选项进行训练。
  5. 评估网络性能

    • 使用测试数据评估训练好的网络的性能,可以计算准确率、混淆矩阵等。
    • 通过 classify 函数对新数据进行分类预测。
  6. 模型调优

    • 可以通过调整LSTM网络结构、训练参数等进行进一步优化模型性能。

在实际的应用中,可以根据具体数据和任务需求对LSTM网络进行调整和优化,以获得更好的分类性能。Matlab提供了丰富的工具和函数来支持LSTM网络的构建、训练和评估,利用这些工具可以更高效地完成序列数据分类任务。

2、基于长短期记忆 (LSTM) 网络对序列数据进行分类说明

使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。

 

3、加载序列数据

1)说明

使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。

从 WaveformData 加载示例数据。

序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。

2)加载数据代码

load WaveformData 

3)绘制部分序列

代码

numChannels = size(data{1},2);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end

视图效果

5e726cb5c6544a79826ce1ec2d56f905.png

4)查看分类

实现代码

classNames = categories(labels)

classNames = 4×1 cell
    {'Sawtooth'}
    {'Sine'    }
    {'Square'  }
    {'Triangle'}

5)划分数据

说明

使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据)

实现代码 

numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);

4、准备要填充的数据

1)说明

默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度

2)获取观测值序列长度代码

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);
end

3)序列长度排序代码

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);

4)查看序列长度

代码

figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

视图效果

28fa6a44a83743e8b30660fcbb63a929.png

 

5、定义 LSTM 神经网络架构

1)说明

将输入大小指定为输入数据的通道数。

指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。

最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。

2)实现代码

numHiddenUnits = 120;
numClasses = 4;

layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]

layers = 
  4×1 Layer array with layers:

     1   ''   Sequence Input    Sequence input with 3 dimensions
     2   ''   BiLSTM            BiLSTM with 120 hidden units
     3   ''   Fully Connected   4 fully connected layer
     4   ''   Softmax           softmax

6、指定训练选项

1)说明

使用 Adam 求解器进行训练。

进行 200 轮训练。

指定学习率为 0.002。

使用阈值 1 裁剪梯度。

为了保持序列按长度排序,禁用乱序。

在图中显示训练进度并监控准确度。

2)实现代码

options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    InitialLearnRate=0.002,...
    GradientThreshold=1, ...
    Shuffle="never", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

7、训练 LSTM 神经网络

1)说明

使用 trainnet 函数训练神经网络

2)实现代码

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

3)视图效果 

cf4856b632c144a58ea5c904d194e96d.png

8、测试 LSTM 神经网络

1)对测试数据进行分类,并计算预测的分类准确度。

numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);

2)对测试数据进行分类,并计算预测的分类准确度。

scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);

3)计算分类准确度

acc = mean(YTest == TTest)

acc = 0.8700

4)混淆图中显示分类结果

figure
confusionchart(TTest,YTest)

0b999fc6591a4438bf3efd7c41517798.png 

9、总结

基于长短期记忆(LSTM)网络对序列数据进行分类是一种重要的深度学习任务,适用于处理具有序列关系的数据,如时间序列数据、自然语言处理等。以下是对使用LSTM网络进行序列数据分类的总结:

  1. LSTM网络结构

    • LSTM是一种适用于处理长期依赖问题的循环神经网络(RNN)变种,能够有效地捕捉序列数据中的长期依赖关系。
    • LSTM网络包含输入门、遗忘门、输出门等核心部分,通过这些门控机制来控制信息的输入、遗忘和输出。
  2. 数据准备

    • 准备带有标签的序列数据,确保数据格式正确且包含标签信息。
    • 进行数据预处理和归一化操作,以便于网络训练。
  3. 网络构建

    • 使用深度学习框架(如TensorFlow、Pytorch或Matlab)构建LSTM网络,定义输入层、LSTM层、全连接层和输出层。
    • 设置网络参数,包括输入维度、LSTM单元个数、输出类别数等。
  4. 模型训练

    • 使用标记好的数据集对构建好的LSTM网络进行训练。
    • 设置优化器、损失函数和训练参数,如学习率、迭代次数等。
    • 调整网络参数以提高模型性能,避免过拟合。
  5. 模型评估

    • 使用验证集或测试集对训练好的模型进行评估,计算准确率、精确率、召回率等指标。
    • 分析模型在不同类别上的表现,进行结果可视化分析。
  6. 模型应用和优化

    • 将训练好的模型用于实际应用中,对新数据进行分类预测。
    • 根据实际需求对模型进行调优和优化,如调整网络结构、训练参数或使用模型集成等方法。

综合来看,基于LSTM网络对序列数据进行分类是一种强大的方法,可在许多领域中发挥作用。通过合理设计网络结构、优化数据准备和训练过程,可以有效地构建出具有良好泛化能力的序列数据分类模型。

10、源代码

代码

%% 基于长短期记忆 (LSTM) 网络对序列数据进行分类
%使用 LSTM 神经网络对序列数据进行分类,LSTM 神经网络将序列数据输入网络,并根据序列数据的各个时间步进行预测。

%% 加载序列数据
%使用 Waveform 数据集,训练数据包含四种波形的时间序列数据。每个序列有三个通道,且长度不同。
%从 WaveformData 加载示例数据。
%序列数据是序列的 numObservations×1 元胞数组,其中 numObservations 是序列数。每个序列都是一个 numTimeSteps×-numChannels 数值数组,其中 numTimeSteps 是序列的时间步,numChannels 是序列的通道数。标签数据是 numObservations×1 分类向量。
load WaveformData 
%绘制部分序列
numChannels = size(data{1},2);

idx = [3 4 5 12];
figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{idx(i)},DisplayLabels="Channel "+string(1:numChannels))
    
    xlabel("Time Step")
    title("Class: " + string(labels(idx(i))))
end
%查看类名称
classNames = categories(labels)
%划分数据
%使用 trainingPartitions 函数将数据划分为训练集(包含 90% 数据)和测试集(包含其余 10% 数据),
numObservations = numel(data);
[idxTrain,idxTest] = trainingPartitions(numObservations,[0.9 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XTest = data(idxTest);
TTest = labels(idxTest);
%% 准备要填充的数据
%默认情况下,软件将训练数据拆分成小批量并填充序列,使它们具有相同的长度
%获取观测值序列长度
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);
end
%序列长度排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);
%查看序列长度
figure
bar(sequenceLengths)
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%% 定义 LSTM 神经网络架构
%将输入大小指定为输入数据的通道数。
%指定一个具有 120 个隐藏单元的双向 LSTM 层,并输出序列的最后一个元素。
%最后,包括一个输出大小与类的数量匹配的全连接层,后跟一个 softmax 层。
numHiddenUnits = 120;
numClasses = 4;

layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
%% 指定训练选项
%使用 Adam 求解器进行训练。
%进行 200 轮训练。
%指定学习率为 0.002。
%使用阈值 1 裁剪梯度。
%为了保持序列按长度排序,禁用乱序。
%在图中显示训练进度并监控准确度。
options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    InitialLearnRate=0.002,...
    GradientThreshold=1, ...
    Shuffle="never", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);
%%  训练 LSTM 神经网络
%使用 trainnet 函数训练神经网络
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
%% 测试 LSTM 神经网络
%对测试数据进行分类,并计算预测的分类准确度。
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
%对测试数据进行分类,并计算预测的分类准确度。
scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);
%计算分类准确度
acc = mean(YTest == TTest)
%混淆图中显示分类结果
figure
confusionchart(TTest,YTest)

工程文件

https://download.csdn.net/download/XU157303764/89499744

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/760485.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

XJTUSE-数据结构-homework1

任务 1 题目: 排序算法设计: 需要写Selection、Shell、Quicksort 和 Mergesort四种排序算法,书上讲述比较全面而且不需要进行额外的优化,下面我简要地按照自己的理解讲述。 Selection(选择排序)&#xff…

HarmonyOS Next开发学习手册——单选框 (Radio)

Radio是单选框组件,通常用于提供相应的用户交互选择项,同一组的Radio中只有一个可以被选中。具体用法请参考 Radio 。 创建单选框 Radio通过调用接口来创建,接口调用形式如下: Radio(options: {value: string, group: string})…

Linux常用工具使用方式

目录 常用工具: 安装包管理工具: 查找含有关键字的软件包 安装软件 安装文件传输工具 安装编辑器 C语言编译器 C编译器 安装调试器 安装项目版本管理工具 cmake 卸载软件 安装jsoncpp 安装boost库 安装mariadb 安装tree(让目录…

Python28-3 朴素贝叶斯分类算法

朴素贝叶斯算法简介 朴素贝叶斯(Naive Bayes)算法是一种基于贝叶斯定理的分类算法。它广泛应用于文本分类、垃圾邮件检测和情感分析等领域。该算法假设特征之间是独立的,这个假设在实际情况中可能并不完全成立,但Naive Bayes在许…

java笔记(30)——反射的 API 及其 使用

文章目录 反射1. 什么是反射2. 获取class字段(字节码文件对象)方式1方式2方式3应用 3. 获取构造方法和权限修饰符前期准备获取所有的公共构造方法获取所有的构造方法获取无参构造方法获取一个参数的构造方法获取一个参数的构造方法获取两个参数的构造方法…

Java面试题--JVM大厂篇之G1 GC的分区管理方式如何减少应用线程的影响

目录 引言: 正文: 1. 区域划分(Region) 2. 并行和并发回收 3. 区域优先回收(Garbage First) 4. 可预测的停顿时间 5. 分阶段回收 6. 复制和压缩 实际效果: 场景举例 1. 减少单次GC的影响 2. 支持高并发环境 3. 优…

数学建模(1):期末大乱炖

1 概述!! 1.1 原型和模型 原型:客观存在的研究对象称为原型,也称为“系统”、“过程”。 机械系统、电力系统、化学反应过程、生产销售过程等都是原型; 研究原型的结构和原理, 从而进行优化、预测、评价…

一区算法MPA|海洋捕食者算法原理及其代码实现(Matlab/Python))

Matlab/Python: 本文KAU将介绍一个2020年发表在1区期刊ESWA上的优化算法——海洋捕食者算法 (Marine Predators Algorithm,MPA)[1] 该算法由Faramarzi等于2020年提出,其灵感来源于海洋捕食者之间不同的觅食策略、最佳相遇概率策略、海洋记…

【MySQL】Linux下MySQL的目录结构、用户、权限与角色

一、Linux下MySQL的目录结构 1、MySQL相关目录 数据库文件存放路径:/var/lib/mysql数据库命令存放路径:/user/bin和/user/sbin配置文件目录:/usr/share/mysql-8.0/、/usr/share/mysql/和/etc/my.cnf 2、假设我们创建了一个数据库dbtest1&a…

使用evo工具比较ORB-SLAM3的运行轨迹(从安装到解决报错)

ORB-SLAM2和ORB-SLAM3怎么跑出来,之前都有相关的保姆级的教程,下来给大家介绍一款evo工具,给科研加速!!! 文章目录 1.下载evo2.生成轨迹3.evo别的功能使用 1.下载evo 输入命令下载 pip install -i https…

你真的会udf提权???数据库权限到系统权限 内网学习 mysql的udf提权操作 ??msf你会用了吗???

我们在已经取得了数据库的账号密码过后,我们要进一步进行提取的操作,我们mysql有4钟提权的操作。 udf提权(最常用的)mof提权启动项提权反弹shell提权操作 怎么获取密码操作: 怎么获取密码,通过sql注入获取这个大家都应该知道了&a…

百强韧劲,进击新局 2023年度中国医药工业百强系列榜单发布

2024年,经济工作坚持稳中求进、以进促稳、先立后破等工作要求。医药健康行业以不懈进取的“韧劲”,立身破局,迎变启新。通过创新和迭代应对不确定性,进化韧性力量,坚持高质量发展,把握新时代经济和社会给予…

零基础开始学习鸿蒙开发-读书app简单的设计与开发

目录 1.首页设计 2.发现页面的设计 3.设置页面的设计 4.导航页设计 5.总结: 6.最终的效果 1.首页设计 Entry Component export struct home {State message: string 首页build() {Row() {Column() {Text(this.message).fontSize(50).fontWeight(FontWeight.B…

基于线调频小波变换的非平稳信号分析方法(MATLAB)

信号处理领域学者为了改进小波变换在各时频区间能量聚集性不高的缺点,有学者在小波分析基础上引入调频算子构成了线性调频小波变换,线调频小波一方面继承了小波变换的理论完善性,另一方面用一个新的参数(线调频参数)刻…

构建高效业财一体化管理体系

构建高效业财一体化管理体系 业财一体化战略意义 提升决策质量 强化数据支撑:通过整合业务与财务数据,为决策提供准确、实时的信息基础,确保分析的深度与广度。促进业务与财务协同:打破信息孤岛,实现业务流程与财务管…

Django 定义模型执行迁移

1,创建应用 Test/app8 python manage.py startapp app8 2,注册应用 Test/Test/settings.py 3,配置路由 Test/Test/urls.py from django.contrib import admin from django.urls import path, includeurlpatterns [path(app8/, include(a…

Linux服务器上安装CUDA11.2和对应的cuDNN 8.4.0

一、检查 检查本机是否有CUDA工具包,输入nvcc -V: 如图所示,服务器上有CUDA,但版本为9.1.85,版本过低,因此博主要重装一个新的。 二、安装CUDA 1.查看服务器最高支持的CUDA版本 在命令行输入nvidia-smi查看显卡驱动…

Mining Engineering First Aid Riding

4个最主要的日常技能:Mining 采矿 Engineering 工程 First Aid 急救 Riding 骑术 4个最主要的日常技能

C# 信号量的使用

学习来源:《.net core 底层入门》 第六章第9节:信号量 案例:主线程负责添加数据,子线程负责获取数据 使用SemaphoreSlim(轻信号量)实现: using System; using System.Collections.Generic; us…

AI写作变现指南:从项目启动到精通

项目启动 1. 确定目标客户群体 首先,明确谁是我们的目标客户。以下是一些潜在的客户群体: 大学生:他们需要写论文、报告、演讲稿等。 职场人士:包括需要撰写商业计划书、市场分析报告、项目提案等的专业人士。 自媒体从业者&…