Pytorch-CNN轴承故障一维信号分类(二)

目录

前言

1 数据集制作与加载

1.1 导入数据

1.2 数据加载,训练数据、测试数据分组,数据分batch

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

2.2 定义模型参数

2.3 模型结构

2.4 模型训练

2.5 模型评估

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

3.2 定义模型参数

3.3 模型结构

3.4 模型训练

3.5 模型评估

4 模型对比


往期精彩内容:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

Python轴承故障诊断 (一)短时傅里叶变换STFT

Python轴承故障诊断 (二)连续小波变换CWT-CSDN博客

Python轴承故障诊断 (三)经验模态分解EMD-CSDN博客

Python轴承故障诊断 (四)基于EMD-CNN的故障分类-CSDN博客

Python轴承故障诊断 (五)基于EMD-LSTM的故障分类-CSDN博客

Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客

前言

本文基于凯斯西储大学(CWRU)轴承数据,先经过数据预处理进行数据集的制作和加载,最后通过Pytorch实现CNN模型一维卷积和二维卷积对故障数据的分类,然后进行对比。凯斯西储大学轴承数据的详细介绍可以参考下文:

Python-凯斯西储大学(CWRU)轴承数据解读与分类处理

1 数据集制作与加载

1.1 导入数据

参考之前的文章,进行故障10分类的预处理,凯斯西储大学轴承数据10分类数据集:

第一步,导入十分类数据

import numpy as np
import pandas as pd
from scipy.io import loadmat

file_names = ['0_0.mat','7_1.mat','7_2.mat','7_3.mat','14_1.mat','14_2.mat','14_3.mat','21_1.mat','21_2.mat','21_3.mat']

for file in file_names:
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file}')
    print(list(data.keys()))

第二步,读取MAT文件驱动端数据

# 采用驱动端数据
data_columns = ['X097_DE_time', 'X105_DE_time', 'X118_DE_time', 'X130_DE_time', 'X169_DE_time',
                'X185_DE_time','X197_DE_time','X209_DE_time','X222_DE_time','X234_DE_time']
columns_name = ['de_normal','de_7_inner','de_7_ball','de_7_outer','de_14_inner','de_14_ball','de_14_outer','de_21_inner','de_21_ball','de_21_outer']
data_12k_10c = pd.DataFrame()
for index in range(10):
    # 读取MAT文件
    data = loadmat(f'matfiles\\{file_names[index]}')
    dataList = data[data_columns[index]].reshape(-1)
    data_12k_10c[columns_name[index]] = dataList[:119808]  # 121048  min: 121265
print(data_12k_10c.shape)
data_12k_10c

第三步,制作数据集

train_set、val_set、test_set 均为按照7:2:1划分训练集、验证集、测试集,最后保存数据

第四步,制作训练集和标签

# 制作数据集和标签
import torch

# 这些转换是为了将数据和标签从Pandas数据结构转换为PyTorch可以处理的张量,
# 以便在神经网络中进行训练和预测。

def make_data_labels(dataframe):
    '''
        参数 dataframe: 数据框
        返回 x_data: 数据集     torch.tensor
            y_label: 对应标签值  torch.tensor
    '''
    # 信号值
    x_data = dataframe.iloc[:,0:-1]
    # 标签值
    y_label = dataframe.iloc[:,-1]
    x_data = torch.tensor(x_data.values).float()
    y_label = torch.tensor(y_label.values.astype('int64')) # 指定了这些张量的数据类型为64位整数,通常用于分类任务的类别标签
    return x_data, y_label

# 加载数据
train_set = load('train_set')
val_set = load('val_set')
test_set = load('test_set')

# 制作标签
train_xdata, train_ylabel = make_data_labels(train_set)
val_xdata, val_ylabel = make_data_labels(val_set)
test_xdata, test_ylabel = make_data_labels(test_set)
# 保存数据
dump(train_xdata, 'trainX_1024_10c')
dump(val_xdata, 'valX_1024_10c')
dump(test_xdata, 'testX_1024_10c')
dump(train_ylabel, 'trainY_1024_10c')
dump(val_ylabel, 'valY_1024_10c')
dump(test_ylabel, 'testY_1024_10c')

1.2 数据加载,训练数据、测试数据分组,数据分batch

import torch
from joblib import dump, load
import torch.utils.data as Data
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# 参数与配置
torch.manual_seed(100)  # 设置随机种子,以使实验结果具有可重复性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载数据集
def dataloader(batch_size, workers=2):
    # 训练集
    train_xdata = load('trainX_1024_10c')
    train_ylabel = load('trainY_1024_10c')
    # 验证集
    val_xdata = load('valX_1024_10c')
    val_ylabel = load('valY_1024_10c')
    # 测试集
    test_xdata = load('testX_1024_10c')
    test_ylabel = load('testY_1024_10c')

    # 加载数据
    train_loader = Data.DataLoader(dataset=Data.TensorDataset(train_xdata, train_ylabel),
                                   batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    val_loader = Data.DataLoader(dataset=Data.TensorDataset(val_xdata, val_ylabel),
                                 batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    test_loader = Data.DataLoader(dataset=Data.TensorDataset(test_xdata, test_ylabel),
                                  batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    return train_loader, val_loader, test_loader

batch_size = 32
# 加载数据
train_loader, val_loader, test_loader = dataloader(batch_size)

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

注意:输入数据进行了堆叠 ,把一个1*1024 的序列 进行划分堆叠成形状为1 * 32 * 32, 就使输入序列的长度降下来了,(channels, seq_length, H_in)

2.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN2DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

2.3 模型结构

2.4 模型训练

训练结果

50个epoch,准确率将近97%,CNN-2D网络分类模型效果良好。

2.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载模型
model =torch.load('best_model_cnn2d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9313  Test Loss: 0.04866932

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

注意:与2d模型的信号长度堆叠不同,CNN-1D模型直接在一维序列上进行卷积池化操作;形状为(batch,H_in, seq_length),利用平均池化 使CNN-1D和CNN-2D模型最后输出维度相同,保持着相近的参数量。

3.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN1DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

3.3 模型结构

3.4 模型训练

训练结果

50个epoch,准确率将近95%,CNN-1D网络分类模型效果良好。

3.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载模型
model =torch.load('best_model_cnn1d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9185  Test Loss: 0.14493044

4 模型对比

对比CNN-2D模型 和CNN-1D模型:

模型参数量训练集准确率验证集准确率测试集准确率
CNN1D61565496.5694.6491.85
CNN2D68343098.3896.8893.13

由于CNN-2D模型参数量稍微多一点,所以模型表现得也略好一点,适当调整参数,两者模型准确率相近。但是CNN-2D推理速度要快于CNN-1D,在轴承故障数据集上,应该更考虑CNN-2D模型在堆叠后的一维信号上进行卷积池化。

注意调整参数:

  • 可以适当增加 CNN层数 和每层神经元个数,微调学习率;

  • 增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/235163.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【密码学基础】Diffie-Hellman密钥交换协议

DH介绍 Diffie-Hellman密钥协议算法是一种确保共享密钥安全穿越不安全网络的方法。 这个机制的巧妙在于需要安全通信的双方可以用这个方法确定对称密钥,然后可以用这个密钥进行加密和解密。 但是注意,这个密钥交换协议 只能用于密钥的交换,而…

java应用打包运行的4种方法

文章目录 1、方法一:打成jar部署运行2、方法二:通过自制启动器的方式运行3、方法三:使用jpackage把java和jdk一起打包4、方法四:使用GraalVM编译为原生应用4.1、使用native-image-agent(Graalvm内工具)工具来收集这些运行库信息4.…

软件无线电SDR-频谱采集python实现

sdr做的频谱采集,保存的500张频谱图,能看出来是什么东西吗?

成都工业学院Web技术基础(WEB)实验四:CSS3布局应用

写在前面 1、基于2022级计算机大类实验指导书 2、代码仅提供参考,前端变化比较大,按照要求,只能做到像,不能做到一模一样 3、图片和文字仅为示例,需要自行替换 4、如果代码不满足你的要求,请寻求其他的…

perl使用Archive::Tar模块进行文件打包

使用perl的Archive::Tar 模块打包文件夹中的指定文件,输出格式 .tar.gz 。下面是perl代码: #! /usr/bin/perl use v5.14; use File::Find; use Archive::Tar;my filesArry (); my $callback sub {push filesArry, $File::Find::name if -f; };find($c…

【Hive】启动beeline连接hive报错解决

1、解决报错2、在datagrip上连接hive 1、解决报错 刚开始一直报错&#xff1a;启动不起来 hive-site.xml需要配置hiveserver2相关的 在hive-site.xml文件中添加如下配置信息 <!-- 指定hiveserver2连接的host --> <property><name>hive.server2.thrift.bin…

如何打印富文本控件中的内容?

出于某种原因&#xff0c;人们确实对打印富文本控件中的内容感到困惑。 我并非打印方面的专家&#xff0c;但是经过对资料的研究的&#xff0c;我也算弄明白了&#xff0c;今天在此记录一下。 解决问题的关键是这个消息&#xff1a;EM_FORMATRANGE。 每次发送这个消息的时候&a…

Vue 3项目的目录结构

使用vite创建完VUE项目后&#xff0c;使用VS Code编辑器打开项目目录&#xff0c;可以看到一个默认生成的项目目录结构 下图是目录结构&#xff1a; 详细介绍.vscode&#xff1a;存放VS Code编辑器的相关配置。 node_modules&#xff1a;存放项目的各种依赖和安装的插件。…

C# WPF上位机开发(通讯协议的编写)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 作为上位机&#xff0c;它很重要的一个部分就是需要和外面的设备进行数据沟通的。很多时候&#xff0c;也就是在这个沟通的过程当中&#xff0c;上…

极狐GitLab 与 Flux 集成实现 GitOps

目录 flux 和 GitOps 极狐GitLab 与 flux 的集成 flux 命令行安装 极狐GitLab flux GitOps GitOps Demo 写在最后 flux 和 GitOps 众所周知&#xff0c;weaveworks 公司在 2017 年提出了 GitOps 这个概念&#xff0c;而 flux 是 weaveworks 开源的一款对 Kubernetes 上的…

【总结】机器学习中的15种分类算法

目录 一、机器学习中的分类算法 1.1 基础分类算法 1.2 集成分类算法 1.3 其它分类算法&#xff1a; 二、各种机器学习分类算法的优缺点 分类算法也称为模式识别&#xff0c;是一种机器学习算法&#xff0c;其主要目的是从数据中发现规律并将数据分成不同的类别。分类算法通…

golang学习笔记——go流水线示例

range与数组、切片、集合 Go 语言中 range 关键字用于 for 循环中迭代数组(array)、切片(slice)、通道(channel)或集合(map)的元素。在数组和切片中它返回元素的索引和索引对应的值&#xff0c;在集合中返回 key-value 对。 for 循环的 range 格式可以对 slice、map、数组、字…

计算机网络课程设计【Python实现】

一、网络聊天程序的设计与实现 1、实验目的 使用Socket编程&#xff0c;了解Socket通信的原理&#xff0c;会使用Socket进行简单的网络编程&#xff0c;并在此基础上编写聊天程序&#xff0c;运行服务器端和客户端&#xff0c;实现多个客户端通过服务器端进行通信。 2、总体设…

MongoDB 的复制(副本集)

本文主要介绍MongoDB的复制&#xff08;副本集&#xff09;。 目录 MongoDB的复制&#xff08;副本集&#xff09;特点搭建步骤注意事项 MongoDB的复制&#xff08;副本集&#xff09; MongoDB复制是一种提供数据冗余和高可用性的方法。复制是通过在多个节点上维护数据的副本来…

机器学习与人工智能:一场革命性的变革

机器学习与人工智能&#xff1a;一场革命性的变革 人工智能的概述什么是机器学习定义解释 数据集结构机器学习应用场景 人工智能的概述 1956年8月&#xff0c;在美国汉诺斯小镇宁静的达特茅斯学院中&#xff0c;约翰麦卡锡&#xff08;John McCarthy&#xff09;、马文闵斯基&…

科学小论文

赵州桥&#xff0c;是一座右拱桥&#xff0c;它座落于河北省石家庄市赵县城南液河之上。 赵州桥因赵县古称赵州而得名&#xff0c;当地人称之为大石桥&#xff0c;以区别于城西门外的永通桥&#xff0c;也称小石桥。 赵州桥始建于隋代&#xff0c;由匠师李春设计建造&#xff…

第一百九十九回 如何获取设备信息

文章目录 1. 概念介绍2. 使用方法3. 代码与效果3.1 示例代码3.2 运行效果 我们在上一章回中介绍了包管理相关的内容&#xff0c;本章回中将介绍如何使用url_launcher包.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们在这里介绍url_launcher包主要用来打开…

万户 ezOFFICE 文件上传漏洞复现

0x01 产品简介 万户OA ezoffice是万户网络协同办公产品多年来一直将主要精力致力于中高端市场的一款OA协同办公软件产品,统一的基础管理平台,实现用户数据统一管理、权限统一分配、身份统一认证。统一规划门户网站群和协同办公平台,将外网信息维护、客户服务、互动交流和日…

HarmonyOS鸿蒙应用开发——HTTP网络访问与封装

文章目录 基本使用封装参考 基本使用 鸿蒙应用发起HTTP请求的基本使用&#xff0c;如下&#xff1a; 导入http模块创建httpRequest对象发起http请求&#xff0c;并处理响应结果 第一、导入http模块&#xff1a; import http from ohos.net.http第二、创建httpRequest对象&a…