Pytorch-LSTM轴承故障一维信号分类(一)

目录

前言

1 数据集制作与加载

1.1 导入数据

第一步,导入十分类数据

第二步,读取MAT文件驱动端数据

第三步,制作数据集

第四步,制作训练集和标签

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 LSTM分类模型和超参数选取

2.1 定义LSTM分类模型

2.2 定义模型参数

2.3 模型结构

3 LSTM模型训练与评估

3.1 模型训练

3.2 模型评估


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT

Python轴承故障诊断 (三)经验模态分解EMD

Python轴承故障诊断 (四)基于EMD-CNN的故障分类

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Python实现LSTM模型对故障数据的分类。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmat

file_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']

for file in file_names:
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file}')
    print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time',
                'X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file_names[index]}')
    dataList = data[data_columns[index]].reshape(-1)
    data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch

# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。

def make_data_labels(dataframe):
    '''
        参数 dataframe: 数据框
        返回 x_data: 数据集     torch.tensor
            y_label: 对应标签值  torch.tensor
    '''
    # 信号值
    x_data = dataframe.iloc[:,0:-1]
    # 标签值
    y_label = dataframe.iloc[:,-1]
    x_data = torch.tensor(x_data.values).float()
    y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签
    return x_data, y_label

# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')

# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载数据集
def dataloader(batch_size, workers=2):
    # 训练集
    train_xdata = load('trainX_1024_10c')
    train_ylabel = load('trainY_1024_10c')
    # 验证集
    val_xdata = load('valX_1024_10c')
    val_ylabel = load('valY_1024_10c')
    # 测试集
    test_xdata = load('testX_1024_10c')
    test_ylabel = load('testY_1024_10c')

    # 加载数据
    train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),
                                   batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),
                                 batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),
                                  batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    return train_loader, val_loader, test_loader

batch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 LSTM分类模型和超参数选取

2.1 定义LSTM分类模型

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为 32 * 32, 就使输入序列的长度降下来了

2.2 定义模型参数

# 定义模型参数
batch_size = 32
input_dim = 32   # 输入维度为一维信号序列堆叠为 32 * 32
hidden_layer_sizes = [256, 128, 64]
output_dim = 10

model = LSTMclassifier(batch_size, input_dim, hidden_layer_sizes, output_dim)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

2.3 模型结构

3 LSTM模型训练与评估

3.1 模型训练

训练结果

200个epoch,准确率将近96%,LSTM网络分类模型效果良好,继续调参还可以进一步提高分类准确率。

注意调整参数:

  • 可以适当增加 LSTM层数 和每层神经元个数,微调学习率;

  • 增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

3.2 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载模型
# model =torch.load('best_model_lstm.pt')
model = torch.load('best_model_lstm.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9570  Test Loss: 0.12100271

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/237538.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++之STL算法(1)

STL容器算法主要由、、组成;   algorithm主要有遍历、比较、交换、查找、拷贝、修改等; 1.遍历容器for_each for_each()函数用于完成容器遍历,函数参数如下: for_each(_InIt _First, _InIt _Last, _Fn _Func) 形参&#xff1a…

mybatis多表映射-延迟加载,延迟加载的前提条件是:分步查询

1、建库建表 create database mybatis-example; use mybatis-example; create table t_book (bid varchar(20) primary key,bname varchar(20),stuid varchar(20) ); insert into t_book values(b001,Java,s001); insert into t_book values(b002,Python,s002); insert into …

基于 librosa和soundfile对音频进行重采样 (VITS 必备)

基于 librosa和soundfile对音频进行重采样 一、前言 在玩bert-vits2的时候有对音频进行重采样的需求,故写了一下批量对音频进行重采样的脚本。 优化点: 根据机器自适应线程数为最多,保证充分利用机器资源,提高速度>30%。支持…

UE引擎 LandscapeGrass 实现机制分析(UE5.2)

前言 随着电脑和手机硬件性能越来越高,游戏越来越追求大世界,而大世界非常核心的一环是植被,目前UE5引擎提供给植被生成的主流两种方式为 手刷植被和LandscapeGrass(WeightMap程序化植被)。当然UE5.3推出新一代PCGFramework 节点程序化生成框…

Android 顶部对齐宽度撑满高度等比例缩放及限制最大最小高度

一 示例 二 代码 <?xml version"1.0" encoding"utf-8"?> <FrameLayout xmlns:android"http://schemas.android.com/apk/res/android"android:layout_width"match_parent"android:layout_height"match_parent&qu…

点评项目——秒杀优化

2023.12.11 上一张的秒杀券下单还可以进行优化&#xff0c;先来回顾一下下单流程&#xff1a; 可以看出流程设计多次查询和操作数据库的操作&#xff0c;并且执行顺序是一个线程串行执行&#xff0c;执行性能是比较低的。 优化方案&#xff1a;我们将判断秒杀库存和校验一人一单…

蓝桥杯周赛 第 1 场 强者挑战赛 6. 小球碰撞【算法赛】(思维题/最长上升子序列LIS)

题目 https://www.lanqiao.cn/problems/9494/learning/?contest_id153 思路来源 Aging代码 题解 二分时间t&#xff0c;第i个小球对应一个起点pi、终点pit*vi的区间&#xff0c;问题转化为&#xff0c; 选最多的区间&#xff0c;使得不存在区间包含&#xff08;即li<l…

第二百零一回 介绍一个三方包open_settings

文章目录 1. 概念介绍2 使用方法3 代码与效果3.1 示例代码3.2 运行效果 4. 经验分享 我们在上一章回中介绍了Form Widget相关的内容&#xff0c;本章回中将介绍Form系列组件的验证与提交功能.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们在这里说的的验…

【电路笔记】-电位器

电位器 文章目录 电位器1、概述2、电位器类型2.1 旋转电位器2.2 滑块电位器2.3 预设和微调电位器2.4 变阻器 3、电位器示例14、电位器作为分压器5、电位器示例26、变阻器6、滑块变阻器7、线性或对数电位器8、总结 当连接的轴物理旋转时&#xff0c;电位计和变阻器的电阻值会发生…

23种设计模式之装饰者模式(被装饰者,接口层,装饰抽象层,具体装饰者)

23种设计模式之装饰者模式 文章目录 23种设计模式之装饰者模式设计思想装饰者模式的优点装饰者模式的缺点装饰者模式的优化方法UML 解析预设场景 代码释义总结 设计思想 原文:装饰器模式&#xff08;Decorator Pattern&#xff09;允许向一个现有的对象添加新的功能&#xff0…

【EMNLP 2023】面向垂直领域的知识预训练语言模型

近日&#xff0c;阿里云人工智能平台PAI与华东师范大学数据科学与工程学院合作在自然语言处理顶级会议EMNLP2023上发表基于双曲空间和对比学习的垂直领域预训练语言模型。通过比较垂直领域和开放领域知识图谱数据结构的不同特性&#xff0c;发现在垂直领域的图谱结构具有全局稀…

做数据分析为何要学统计学(3)——何为置信区间?它有什么作用?

置信区间是统计学中的一个重要工具&#xff0c;用以使用样本参数()来估计总体均值在某置信水平下的范围。通俗一点讲&#xff0c;如果置信度为95%&#xff08;等价于显著水平a0.05&#xff09;&#xff0c;置信区间为[a,b]&#xff0c;这就意味着总体均值落入该区间的概率为95%…

虹科Pico汽车示波器 | 汽车免拆检修 | 2019款别克GL8豪华商务车前照灯水平调节故障

一、故障现象 一辆2019款别克GL8豪华商务车&#xff0c;搭载LTG发动机&#xff0c;累计行驶里程约为10.7万km。车主反映&#xff0c;车辆行驶过程中组合仪表提示前照灯水平调节故障。 二、故障诊断 接车后试车&#xff0c;起动发动机&#xff0c;组合仪表上提示“前照灯水平调节…

Spring Boot监听redis过期的key

Redis支持过期监听&#xff0c;可以实现监听过期数据&#xff0c;实现过程如下 1、pom依赖 <!-- Redis--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></depend…

ChatGPT/GPT4应用:文本、论文、编程、绘图等,提高工作效率及科研项目开发能力

2023年随着OpenAI开发者大会的召开&#xff0c;最重磅更新当属GPTs&#xff0c;多模态API&#xff0c;未来自定义专属的GPT。微软创始人比尔盖茨称ChatGPT的出现有着重大历史意义&#xff0c;不亚于互联网和个人电脑的问世。360创始人周鸿祎认为未来各行各业如果不能搭上这班车…

深入理解模板引擎:解锁 Web 开发的新境界(下)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

Android蓝牙协议栈fluoride(五) - 设备管理(bt application)

在上一篇Android蓝牙协议栈fluoride(四) - 设备管理(bt interface) 中梳理了设备管理器对上层提供的接口&#xff0c;本文将介绍这些接口的具体实现。 各个模块中采用了API状态机数据收发的方式&#xff0c;介绍设备管理时也将采用这个顺序介绍。 核心数据结构 设备管理的核…

鸿蒙HarmonyOS4.0 入门与实战

一、开发准备: 熟悉鸿蒙官网安装DevEco Studio熟悉鸿蒙官网 HarmonyOS应用开发官网 - 华为HarmonyOS打造全场景新服务 应用设计相关资源: 开发相关资源: 例如开发工具 DevEco Studio 的下载 应用发布: 开发文档:

论文阅读《High-frequency Stereo Matching Network》

论文地址&#xff1a;https://openaccess.thecvf.com/content/CVPR2023/papers/Zhao_High-Frequency_Stereo_Matching_Network_CVPR_2023_paper.pdf 源码地址&#xff1a; https://github.com/David-Zhao-1997/High-frequency-Stereo-Matching-Network 概述 在立体匹配研究领域…

OpenAI承认GPT-4变懒,即将发布修复方案提升性能

目录 1OpenAI承认GPT-4变懒&#xff0c;即将发布修复方案提升性能 2一文秒懂人工智能全球近况 1OpenAI承认GPT-4变懒&#xff0c;即将发布修复方案提升性能 **划重点:** 1. &#x1f92f; 用户反馈:GPT-4使用者抱怨OpenAI破坏了体验&#xff0c;称模型几乎“害怕”提供答案。…