Python电能质量扰动信号分类(二)基于CNN模型的一维信号分类

目录

前言

1 电能质量数据集制作与加载

1.1 导入数据

1.2 制作数据集

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

2.2 定义模型参数

2.3 模型结构

2.4 模型训练

2.5 模型评估

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

3.2 定义模型参数

3.3 模型结构

3.4 模型训练

3.5 模型评估

4 模型对比


往期精彩内容:

电能质量扰动信号数据介绍与分类-Python实现-CSDN博客

Python电能质量扰动信号分类(一)基于LSTM模型的一维信号分类-CSDN博客

轴承故障诊断分类模型全家桶-最全教程-CSDN博客

前言

本文基于Python仿真的电能质量扰动信号,先经过数据预处理进行数据集的制作和加载,然后通过Pytorch实现CNN模型一维卷积和二维卷积对扰动信号的分类。Python仿真电能质量扰动信号的详细介绍可以参考下文(文末附10分类数据集):

电能质量扰动信号数据介绍与分类-Python实现-CSDN博客

部分扰动信号类型波形图如下所示:

1 电能质量数据集制作与加载

1.1 导入数据

在参考IEEE Std1159-2019电能质量检测标准与相关文献的基础上构建了扰动信号的模型,生成包括正常信号在内的10中单一信号和多种复合扰动信号。参考之前的文章,进行扰动信号10分类的预处理:

第一步,按照公式模型生成单一信号

单一扰动信号可视化:

第二步,导入十分类数据

import pandas as pd
import numpy as np

# 样本时长0.2s  样本步长1024  每个信号生成500个样本  噪声0DB  
window_step = 1024
samples = 500
noise = 0
split_rate = [0.7, 0.2, 0.1]  # 训练集、验证集、测试集划分比例

# 读取已处理的 CSV 文件
dataframe_10c = pd.read_csv('PDQ_10c_Clasiffy_data.csv' )
dataframe_10c.shape

1.2 制作数据集

第一步,定义制作数据集函数

第二步,制作数据集与分类标签

from joblib import dump, load
# 生成数据
train_dataframe, val_dataframe, test_dataframe = make_data(dataframe_10c, split_rate)
# 制作标签
train_xdata, train_ylabel = make_data_labels(train_dataframe)
val_xdata, val_ylabel = make_data_labels(val_dataframe)
test_xdata, test_ylabel = make_data_labels(test_dataframe)
# 保存数据
dump(train_xdata, 'TrainX_1024_0DB_10c')
dump(val_xdata, 'ValX_1024_0DB_10c')
dump(test_xdata, 'TestX_1024_0DB_10c')
dump(train_ylabel, 'TrainY_1024_0DB_10c')
dump(val_ylabel, 'ValY_1024_0DB_10c')
dump(test_ylabel, 'TestY_1024_0DB_10c')

2 CNN-2D分类模型和训练、评估

2.1 定义CNN-2d分类模型

2.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN2DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

2.3 模型结构

2.4 模型训练

训练结果

50个epoch,准确率将近97%,CNN-2D网络分类模型效果良好。

2.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载模型
model =torch.load('best_model_cnn2d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9313  Test Loss: 0.04866932

3 CNN-1D分类模型和训练、评估

3.1 定义CNN-1d分类模型

注意:与2d模型的信号长度堆叠不同,CNN-1D模型直接在一维序列上进行卷积池化操作;形状为(batch,H_in, seq_length),利用平均池化 使CNN-1D和CNN-2D模型最后输出维度相同,保持着相近的参数量。

3.2 定义模型参数

# 定义模型参数
batch_size = 32
# 先用浅层试一试
conv_arch = ((2, 32), (1, 64), (1, 128))  
input_channels = 1
num_classes = 10
model = CNN1DModel(conv_arch, num_classes, batch_size)  
# 定义损失函数和优化函数
model = model.to(device)
loss_function = nn.CrossEntropyLoss(reduction='sum')  # loss
learn_rate = 0.0003
optimizer = torch.optim.Adam(model.parameters(), learn_rate)  # 优化器

3.3 模型结构

3.4 模型训练

训练结果

100个epoch,准确率将近95%,CNN-1D网络分类模型效果良好。

3.5 模型评估

# 模型 测试集 验证  
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 有GPU先用GPU训练

# 加载模型
model =torch.load('best_model_cnn1d.pt')
# model = torch.load('best_model_cnn2d.pt', map_location=torch.device('cpu'))

# 将模型设置为评估模式
model.eval()
# 使用测试集数据进行推断
with torch.no_grad():
    correct_test = 0
    test_loss = 0
    for test_data, test_label in test_loader:
        test_data, test_label = test_data.to(device), test_label.to(device)
        test_output = model(test_data)
        probabilities = F.softmax(test_output, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1)
        correct_test += (predicted_labels == test_label).sum().item()
        loss = loss_function(test_output, test_label)
        test_loss += loss.item()

test_accuracy = correct_test / len(test_loader.dataset)
test_loss = test_loss / len(test_loader.dataset)
print(f'Test Accuracy: {test_accuracy:4.4f}  Test Loss: {test_loss:10.8f}')

Test Accuracy: 0.9185  Test Loss: 0.14493044

4 模型对比

对比CNN-2D模型 和CNN-1D模型:

模型参数量训练集准确率验证集准确率测试集准确率
CNN1D61565496.5694.6491.85
CNN2D68343098.3896.8893.13

由于CNN-2D模型参数量稍微多一点,所以模型表现得也略好一点,适当调整参数,两者模型准确率相近。但是CNN-2D推理速度要快于CNN-1D,在电能质量扰动信号数据集上,应该更考虑CNN-2D模型在堆叠后的一维信号上进行卷积池化。

注意调整参数:

  • 可以适当增加 CNN层数 和每层神经元个数,微调学习率;

  • 增加更多的 epoch (注意防止过拟合)

  • 可以改变一维信号堆叠的形状(设置合适的长度和维度)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/266509.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文阅读——BLIP-2

BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models 1 模型 在预训练视觉模型和预训练大语言模型中间架起了一座桥梁。两阶段训练,视觉文本表示和视觉到语言生成学习。 Q-Former由两个转换器子模块组成&am…

六大开源 OA 办公系统

OA,即Office Automation的缩写,意思是办公自动化、协同办公。在现代办公环境中,办公自动化已经成为了必不可少的一部分,它可以代替办公人员传统的手动部分或重复性业务活动,优质而高效地处理办公事务和业务信息,实现对…

Openwrt AP 发射 WiFi 信号

问题 想一次把 OpenWrt 路由器 wifi 问题给解决,完全取代路由器。 使用 倍控的 N5105 设备,有 mPCIe 接口,使用了 intel AX200 无线网卡,支持 2.4G 与 5G。 设置步骤 OpenWrt 镜像 第一次使用的镜像不支持 wifi,在…

模式识别与机器学习(八):决策树

1.原理 决策树(Decision Tree),它是一种以树形数据结构来展示决策规则和分类结果的模型,作为一种归纳学习算法,其重点是将看似无序、杂乱的已知数据,通过某种技术手段将它们转化成可以预测未知数据的树状模…

论文笔记--Learning Political Polarization on Social Media Using Neural Networks

论文笔记--Learning Political Polarization on Social Media Using Neural Networks 1. 文章简介2. 文章概括3. 相关工作4. 文章重点技术4.1 Collection of posts4.1.1 数据下载4.1.2 数据预处理4.1.3 统计显著性分析 4.2 Classification of Posts4.3 Polarization of users 5…

自然语言处理(NLP):理解语言,赋能未来

目录 前言1 什么是NLP2 NLP的用途3 发展历史4 NLP的基本任务4.1 词性标注(Part-of-Speech Tagging)4.2 命名实体识别(Named Entity Recognition)4.3 共指消解(Co-reference Resolution)4.4 依存关系分析&am…

1855_emacs_compnay的使用探索

Grey 全部学习内容汇总: GitHub - GreyZhang/editors_skills: Summary for some common editor skills I used. 1855_emacs_compnay的使用探索 company其实是一个老伙伴了,之前我emacs中体验提升的主力插件之一。主要是用来做各种场景下的补全&#x…

物联网产品设计,聊聊设备OTA的升级

物联网产品设计部分的OTA设备固件是一个非常重要的部分,能够实现升级用户服务、保障系统安全等功能。 在迅速变化和发展的物联网市场,新的产品需求不断涌现,因此对于智能硬件设备的更新需求就变得空前高涨,设备不再像传统设备一样…

simulinkveristandlabview联合仿真——模型导入搭建人机界面

目录 1.软件版本 2.搭建simulink仿真模型 编译错误 3.导入veristand并建立工程 4.veristand导入labview labview显示veristand工程数据 labview设置veristand工程数据 运行labview工程 1.软件版本 matlab2020a,veristand2020 R4,labview2020 SP…

7种常见的网络安全设备及其功能

网络安全设备在现代网络环境中起着至关重要的作用,帮助保护个人和组织免受恶意攻击。本文将介绍7种常见的网络安全设备,包括防火墙、入侵检测系统、反病毒软件、数据加密设备、虚拟私人网络、安全信息和事件管理系统以及网络访问控制设备,并详…

阅读笔记-A Cluster Separation Measure

A Cluster Separation Measure(一种聚类分离测度) 1.这篇论文要解决什么问题?要验证一个什么科学假设? 问题是确定数据中聚类的适当数量,解决这种问题的两种方法都取决于确定指数中相对较大的变化,而不是…

将PPT的图保持高分辨率导入到Word / WPS中

1、将PPT中画好的图组合在一起,选择组合后的图复制(Ctrlc) 2、在Word中,选中左上角的粘贴选项--->选择性粘贴 WPS选择元文件 / Word选择增强型图元文件 这样放大也不模糊了

Gateway API

Gateway API 目录 原文链接 https://onedayxyy.cn/docs/GatewayAPI 本节实战 实战名称🚩 实战:Gateway API在istio里的安装及测试-2023.12.23(测试失败) 前言 Gateway API 是由 SIG-NETWORK 社区管理的开源项目,项目地址:http…

【一起学Rust | 框架篇 | Tauri2.0框架】Tauri2.0环境搭建与项目创建

文章目录 前言一、搭建 Tauri 2.0 开发环境二、创建 Tauri 2.0 项目1.创建项目2.安装依赖4. 编译运行 三、设置开发环境四、项目结构 前言 Tauri在Rust圈内成名已久,凭借Rust的可靠性,使用系统原生的Webview构建更小的App 以及开发人员可以灵活的使用各…

阿里云 ARMS 应用监控重磅支持 Java 21

作者:牧思 & 山猎 前言 今年的 9 月 19 日,作为最新的 LTS (Long Term Support) Java 版本,Java 21 正式 GA,带来了不少重量级的更新,详情请参考 The Arrival of Java 21 [ 1] 。虽然目前 Java 11 和 Java 17 都…

AI创作系统ChatGPT系统源码,支持Midjourney绘画,GPT语音对话+DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Ch…

【RabbitMQ】RabbitMQ详解(一)

RabbitMQ详解 RabbitMQ介绍四大核心概念RabbitMQ 入门案例Hello RabbitMQ生产者消费者 Work Queues轮询消费 RabbitMQ消息应答与发布自动应答手动应答的方法肯定确认应答否定确认应答拒绝处理该消息恢复到消息队列Multiple的解释消息重新入队手动应答案例 RabbitMQ持久化队列持…

面向对象设计与分析40讲(12)简单工厂方法模式

文章目录 定义示例优缺点 定义 简单工厂模式是一种创建型模式,用于根据客户端的需求创建对象实例,所谓的需求反映到编程语言里就是传入的参数。 简单工厂模式包括三个主要部分: 工厂类(Simple Factory):…

python dash学习2

代码 内有说明: from dash import Dash, html, dcc, callback, Output, Input import plotly.express as px import pandas as pd# 从 Plotly 数据集中读取数据 df pd.read_csv(https://raw.githubusercontent.com/plotly/datasets/master/gapminder_unfiltered.c…

STM32实现流水灯

led.c #include"led.h"void Led_Init(void) {GPIO_InitTypeDef GPIO_VALUE; //???RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOC,ENABLE);//???GPIO_VALUE.GPIO_ModeGPIO_Mode_Out_PP;//???? ????GPIO_VALUE.GPIO_PinGPIO_Pin_1|GPIO_Pin_2|GPIO_Pin_…