【信号处理】基于CNN自编码器的心电信号异常检测识别(tensorflow)

关于

本项目主要实现卷积自编码器对于异常心电ECG信号的检测和识别,属于无监督学习中的生理信号检测的典型方法之一。

工具

 

方法实现

读取心电信号
normal_df = pd.read_csv("/heartbeat/ptbdb_normal.csv").iloc[:, :-1]
anomaly_df = pd.read_csv("/heartbeat/ptbdb_abnormal.csv").iloc[:, :-1]
normal_df.head()

信号可视化

def plot_sample(normal, anomaly):
    index = np.random.randint(0, len(normal_df), 2)
    
    fig, ax = plt.subplots(1, 2, sharey=True, figsize=(10, 4))
    ax[0].plot(normal.iloc[index[0], :].values, label=f"Case {index[0]}")
    ax[0].plot(normal.iloc[index[1], :].values, label=f"Case {index[1]}")
    ax[0].legend(shadow=True, frameon=True, facecolor="inherit", loc=1, fontsize=9)
    ax[0].set_title("Normal")
    
    ax[1].plot(anomaly.iloc[index[0], :].values, label=f"Case {index[0]}")
    ax[1].plot(anomaly.iloc[index[1], :].values, label=f"Case {index[1]}")
    ax[1].legend(shadow=True, frameon=True, facecolor="inherit", loc=1, fontsize=9)
    ax[1].set_title("Anomaly")
    
    plt.tight_layout()
    plt.show()


plot_sample(normal_df, anomaly_df)

 

 信号均值计算及可视化
def plot_smoothed_mean(data, class_name = "normal", step_size=5, ax=None):
    df = pd.DataFrame(data)
    roll_df = df.rolling(step_size)
    smoothed_mean = roll_df.mean().dropna().reset_index(drop=True)
    smoothed_std = roll_df.std().dropna().reset_index(drop=True)
    margin = 3*smoothed_std
    lower_bound = (smoothed_mean - margin).values.flatten()
    upper_bound = (smoothed_mean + margin).values.flatten()

    ax.plot(smoothed_mean.index, smoothed_mean)
    ax.fill_between(smoothed_mean.index, lower_bound, y2=upper_bound, alpha=0.3, color="red")
    ax.set_title(class_name, fontsize=9)

fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
axes = axes.flatten()
for i, label in enumerate(CLASS_NAMES, start=1):
    data_group = df.groupby("target")
    data = data_group.get_group(label).mean(axis=0, numeric_only=True).to_numpy()
    plot_smoothed_mean(data, class_name=label, step_size=20, ax=axes[i-1])
fig.suptitle("Plot of smoothed mean for each class", y=0.95, weight="bold")
plt.tight_layout()

 训练/测试数据划分
normal_df.drop("target", axis=1, errors="ignore", inplace=True)
normal = normal_df.to_numpy()
anomaly_df.drop("target", axis=1, errors="ignore", inplace=True)
anomaly = anomaly_df.to_numpy()

X_train, X_test = train_test_split(normal, test_size=0.15, random_state=45, shuffle=True)
print(f"Train shape: {X_train.shape}, Test shape: {X_test.shape}, anomaly shape: {anomaly.shape}")
搭建自编码器
class AutoEncoder(Model):
    def __init__(self, input_dim, latent_dim):
        super(AutoEncoder, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(input_dim,)),
            layers.Reshape((input_dim, 1)),  # Reshape to 3D for Conv1D
            layers.Conv1D(128, 3, strides=1, activation='relu', padding="same"),
            layers.BatchNormalization(),
            layers.MaxPooling1D(2, padding="same"),
            layers.Conv1D(128, 3, strides=1, activation='relu', padding="same"),
            layers.BatchNormalization(),
            layers.MaxPooling1D(2, padding="same"),
            layers.Conv1D(latent_dim, 3, strides=1, activation='relu', padding="same"),
            layers.BatchNormalization(),
            layers.MaxPooling1D(2, padding="same"),
        ])
        # Previously, I was using UpSampling. I am trying Transposed Convolution this time around.
        self.decoder = tf.keras.Sequential([
            layers.Conv1DTranspose(latent_dim, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),
            layers.BatchNormalization(),
            layers.Conv1DTranspose(128, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),
            layers.BatchNormalization(),
            layers.Conv1DTranspose(128, 3, strides=1, activation='relu', padding="same"),
#             layers.UpSampling1D(2),
            layers.BatchNormalization(),
            layers.Flatten(),
            layers.Dense(input_dim)
        ])

    def call(self, X):
        encoded = self.encoder(X)
        decoded = self.decoder(encoded)
        return decoded


input_dim = X_train.shape[-1]
latent_dim = 32

model = AutoEncoder(input_dim, latent_dim)
model.build((None, input_dim))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mae")
model.summary()
模型训练
epochs = 100
batch_size = 128
early_stopping = EarlyStopping(patience=10, min_delta=1e-3, monitor="val_loss", restore_best_weights=True)


history = model.fit(X_train, X_train, epochs=epochs, batch_size=batch_size,
                    validation_split=0.1, callbacks=[early_stopping])
训练可视化
plt.plot(history.history['loss'], label="Training loss")
plt.plot(history.history['val_loss'], label="Validation loss", ls="--")
plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
plt.title("Training loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.show()

 

信号重建可视化
fig, axes = plt.subplots(2, 5, sharey=True, sharex=True, figsize=(12, 6))
random_indexes = np.random.randint(0, len(X_train), size=5)

for i, idx in enumerate(random_indexes):
    data = X_train[[idx]]
    plot_examples(model, data, ax=axes[0, i], title="Normal")

for i, idx in enumerate(random_indexes):
    data = anomaly[[idx]]
    plot_examples(model, data, ax=axes[1, i], title="anomaly")

plt.tight_layout()
fig.suptitle("Sample plots (Actual vs Reconstructed by the CNN autoencoder)", y=1.04, weight="bold")
fig.savefig("autoencoder.png")
plt.show()

计算重建MAE误差
train_mae = model.evaluate(X_train, X_train, verbose=0)
test_mae = model.evaluate(X_test, X_test, verbose=0)
anomaly_mae = model.evaluate(anomaly_df, anomaly_df, verbose=0)

print("Training dataset error: ", train_mae)
print("Testing dataset error: ", test_mae)
print("Anormaly dataset error: ", anomaly_mae)

 异常检测阈值选取

MAE误差阈值=正常数据重建MAE均值+正常数据重建MAE标准差,此阈值可以用来直接检测某信号为正常信号还是异常心电信号。

def predict(model, X):
    pred = model.predict(X, verbose=False)
    loss = mae(pred, X)
    return pred, loss


_, train_loss = predict(model, X_train)
_, test_loss = predict(model, X_test)
_, anomaly_loss = predict(model, anomaly)
threshold = np.mean(train_loss) + np.std(train_loss) # Setting threshold for distinguish normal data from anomalous data

bins = 40
plt.figure(figsize=(9, 5), dpi=100)
sns.histplot(np.clip(train_loss, 0, 0.5), bins=bins, kde=True, label="Train Normal")
sns.histplot(np.clip(test_loss, 0, 0.5), bins=bins, kde=True, label="Test Normal")
sns.histplot(np.clip(anomaly_loss, 0, 0.5), bins=bins, kde=True, label="anomaly")

ax = plt.gca()  # Get the current Axes
ylim = ax.get_ylim()
plt.vlines(threshold, 0, ylim[-1], color="k", ls="--")
plt.annotate(f"Threshold: {threshold:.3f}", xy=(threshold, ylim[-1]), xytext=(threshold+0.009, ylim[-1]),
             arrowprops=dict(facecolor='black', shrink=0.05), fontsize=9)
plt.legend(shadow=True, frameon=True, facecolor="inherit", loc="best", fontsize=9)
plt.show()

模型评估
plot_confusion_matrix(model, X_train, X_test, anomaly, threshold=threshold)

ytrue, ypred = prepare_labels(model, X_train, X_test, anomaly, threshold=threshold)
print(classification_report(ytrue, ypred, target_names=CLASS_NAMES))

 

代码获取

相关项目开发和问题,欢迎后台沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562802.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MultiHeadAttention在Tensorflow中的实现原理

前言 通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。 一、MultiHeadAttention的本质内涵 1.Self_Atention机制 MultiHeadAttention是Self_Atention的多头堆嵌,有必要对Self_Atention机制进行一次深入浅出的理解,这…

websocket聊天的功能

第一步 安装相关依赖: node需要安装: npm i socket.io 第二步 前端cdn引入socket 第三步 编写服务端的代码 import http from node:http‘import {Server} from socket.ioconst server http.createServer()const io new Server(server,{cors:true …

Drive Scope for Mac:硬盘健康监测分析工具

Drive Scope for Mac是一款专为Mac用户设计的硬盘健康监测与分析工具,致力于保障用户的数据安全。这款软件功能强大且操作简便,能够实时检测硬盘的各项指标,帮助用户及时发现并解决潜在问题。 Drive Scope for Mac 1.2.23注册激活版下载 Driv…

配置 rust国内源

rust crate.io 配置国内源(cargo 国内源) warning: spurious network error (2 tries remainin..._warning: spurious network error (3 tries remaining-CSDN博客

政安晨:【Keras机器学习示例演绎】(七)—— 利用 NeRF 进行 3D 体积渲染

目录 简介 设置 下载并加载数据 NeRF 模型 训练 可视化训练步骤 推理 渲染三维场景 可视化视频 结论 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras机器学习实战 希望政安晨的博客能够对您有所裨益&#xff0…

Ubuntu无法安装向日癸15.2.0.63062_amd64.deb最新版

Ubuntu安装向日葵远程控制 安装包下载 安装方式 方式一:运行安装包安装 方式二:终端命令安装 通过以下教程可以快速的安装向日葵远程控制,本教程适用于Ubuntu18.04/20.04/22.04 安装包下载 进入向日葵远程控制下载官网下载向日葵远程控制Lin…

「 网络安全常用术语解读 」什么是0day、1day、nday漏洞

1. 引言 漏洞攻击的时间窗口被称为漏洞窗口(window of vulnerability)。一般来说,漏洞窗口持续的时间越长,攻击者可以利用漏洞进行攻击的可能性就越大。 2. 0day 漏洞 0day 漏洞,又被称为"零日漏洞"&…

软件杯 深度学习实现行人重识别 - python opencv yolo Reid

文章目录 0 前言1 课题背景2 效果展示3 行人检测4 行人重识别5 其他工具6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 **基于深度学习的行人重识别算法研究与实现 ** 该项目较为新颖,适合作为竞赛课题方向&#xff0c…

了解网卡、光猫、路由器

了解网卡、光猫、路由器 一、网卡二、光猫三、路由器四、光猫和路由器的联系和区别五、家庭正常上网的简单流程六、企业正常上网的简单流程 一、网卡 网卡:用来允许计算机在计算机网络上进行通讯的计算机硬件 一般来说,笔记本都有两种网卡,有…

Flink窗口机制

1.窗口的概念 时间是为窗口服务的。窗口是什么?为什么会有窗口呢? (1)Flink要处理的数据,一般是从Kafka过来的流式数据,如果只是单纯地统计流的数据量,是没办法统计的。 (2&#xff…

Linux-软件安装--jdk安装

jdk安装 前言1、软件安装方式二进制发布包安装rpm安装yum安装源码编译安装 2、安装jdk2.1、使用finalShell自带的上传工具将jdk的二进制发布包上传到Linux2.2、解压安装包2.3、配置环境变量![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/61ba9750e2e34638a39575c5…

【OpenHarmony-NDK技术】简单将cJson移植到OpenHarmony中,并在c层修改参数值再返回json

1、cJson的简单介绍 cJson - github网址 概述 一般使用cJson是,需要将json文本转化为json对象–编码,将json对象转化为json文本–解析。 git clone https://github.com/DaveGamble/cJSON.git 后留意cJSON.h和cJSON.h两个文件。 1、cJson的介绍 cJso…

决策树分类任务实战(python 代码详解)

目录 一、导入库、数据集、并划分训练集和测试集 二、参数调优 (一)第一种调参方法:for循环 (1)单参数优化 ①单参数优化(无K折交叉验证) ②单参数K折交叉验证 优化 (2)多参数优化 ①多参数优化(无K折交叉验证) 参数介绍: ②多参数K折交叉验证…

ruoyi-vue前端的一些自定义插件介绍

文章目录 自定义列表$tab对象打开页签关闭页签刷新页签 $modal对象提供成功、警告和错误等反馈信息(无需点击确认)提供成功、警告和错误等提示信息(类似于alert,需要点确认)提供成功、警告和错误等提示信息&#xff08…

常见现代卷积神经网络(ResNet, DenseNet)(Pytorch 11)

一 批量规范化(batch normalization) 训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。批量规范化(batch normalization)是一种流行且有效的技术,可持续加速深层网络的收敛速度。 …

数据结构与算法解题-20240421

数据结构与算法解题-20240421 一、278. 第一个错误的版本二、541. 反转字符串 II三、右旋字符串四、替换数字五、977.有序数组的平方 一、278. 第一个错误的版本 简单 你是产品经理,目前正在带领一个团队开发新的产品。不幸的是,你的产品的最新版本没有…

数据可视化(七):Pandas香港酒店数据高级分析,涉及相关系数,协方差,数据离散化,透视表等精美可视化展示

Tips:"分享是快乐的源泉💧,在我的博客里,不仅有知识的海洋🌊,还有满满的正能量加持💪,快来和我一起分享这份快乐吧😊! 喜欢我的博客的话,记得…

centos7搭建git服务器

1.centos7安装git yum install -y git yum install -y git-daemon 2.初始化空目录仓库 mkdir /usr/local/git mkdir /usr/local/git/projects mkdir /usr/local/git/projects/test-projects.git cd test-projects.git git --bare init 3.修改目录权限 cd .. chmod 775 tes…

力扣283. 移动零

Problem: 283. 移动零 文章目录 题目描述思路复杂度Code 题目描述 思路 1.定义一个int类型变量index初始化为0; 2.遍历nums当当前的元素nums[i]不为0时使nums[i]赋值给nums[index]; 3.从index开始将nums中置对应位置的元素设为0; 复杂度 时间…

第三届 SWCTF-Web 部分 WP

写在前面 题目主要涉及的是前端 php 内容知识,仅以本篇博客记录自己 Web 出题的奇思妙想。 Copyright © [2024] [Myon⁶]. All rights reserved. 目录 1、HTTP 2、再见了晚星 3、myon123_easy_php 4、baby_P0P 5、LOGIN!!! 1、HTTP 首页文件默认就是 ind…