ECG Anomaly Detection with an Autoencoder (PyTorch)

The code is fairly straightforward and easy to follow.

# Core libraries for data handling and numerical work
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import copy


# PyTorch modules for building and training the network
import torch
from torch import nn, optim
import torch.nn.functional as F


# Additional libraries for visualization and data splitting
import seaborn as sns
from pylab import rcParams
from matplotlib import rc
from sklearn.model_selection import train_test_split


# Suppress warnings to keep the output clean
import warnings
warnings.filterwarnings('ignore')
df = pd.read_csv('http://storage.googleapis.com/download.tensorflow.org/data/ecg.csv',header=None)
df.head().T

df.describe()

df.isna().sum()
0      0
1      0
2      0
3      0
4      0
      ..
136    0
137    0
138    0
139    0
140    0
Length: 141, dtype: int64
df.dtypes
0      float64
1      float64
2      float64
3      float64
4      float64
        ...   
136    float64
137    float64
138    float64
139    float64
140    float64
Length: 141, dtype: object
new_columns = list(df.columns)
new_columns[-1] = 'target'
df.columns = new_columns
df.target.value_counts()
1.0    2919
0.0    2079
Name: target, dtype: int64
value_counts = df['target'].value_counts()


# Plotting
plt.figure(figsize=(8, 6))
value_counts.plot(kind='bar', color='skyblue')
plt.title('Value Counts of Target Column')
plt.xlabel('Target Values')
plt.ylabel('Count')


# Display the count values on top of the bars
for i, count in enumerate(value_counts):
    plt.text(i, count + 0.1, str(count), ha='center', va='bottom')


plt.show()

classes = df.target.unique()


def plot_ecg(data, class_name, ax, n_steps=10):
    # Convert data to a DataFrame
    time_series_df = pd.DataFrame(data)


    # Apply a moving average for smoothing
    smooth_data = time_series_df.rolling(window=n_steps, min_periods=1).mean()


    # Calculate upper and lower bounds for confidence interval
    deviation = time_series_df.rolling(window=n_steps, min_periods=1).std()
    upper_bound = smooth_data + deviation
    lower_bound = smooth_data - deviation


    # Plot the smoothed data
    ax.plot(smooth_data, color='black', linewidth=2)


    # Plot the confidence interval
    ax.fill_between(time_series_df.index, lower_bound[0], upper_bound[0], color='black', alpha=0.2)


    # Set the title
    ax.set_title(class_name)
# Plotting setup
fig, axs = plt.subplots(
    nrows=len(classes) // 3 + 1,
    ncols=3,
    sharey=True,
    figsize=(14, 8)
)


# Plot for each class
for i, cls in enumerate(classes):
    ax = axs.flat[i]
    data = df[df.target == cls].drop(labels='target', axis=1).mean(axis=0).to_numpy()
    plot_ecg(data, cls, ax)  # Using 'cls' directly as class name


# Adjust layout and remove extra axes
fig.delaxes(axs.flat[-1])
fig.tight_layout()


plt.show()

# In this dataset the label 1 marks normal heartbeats; all other classes are treated as anomalous
normal_df = df[df.target == 1].drop(labels='target', axis=1)
normal_df.shape
(2919, 140)
anomaly_df = df[df.target != 1].drop(labels='target', axis=1)
anomaly_df.shape
(2079, 140)
# Splitting the Dataset


# Initial Train-Validation Split:
# The dataset 'normal_df' is divided into training and validation sets.
# 15% of the data is allocated to the validation set.
# The use of 'random_state=42' ensures reproducibility.


train_df, val_df = train_test_split(
  normal_df,
  test_size=0.15,
  random_state=42
)


# Further Splitting for Validation and Test:
# The validation set obtained in the previous step is further split into validation and test sets.
# 33% of the validation set is allocated to the test set.
# The same 'random_state=42' is used for consistency in randomization.


val_df, test_df = train_test_split(
  val_df,
  test_size=0.33,
  random_state=42
)
# Function to Create a Dataset
def create_dataset(df):
    # Convert DataFrame to a list of sequences, each represented as a list of floats
    sequences = df.astype(np.float32).to_numpy().tolist()


    # Convert each sequence to a PyTorch tensor of shape (seq_len, 1)
    dataset = [torch.tensor(s).unsqueeze(1).float() for s in sequences]


    # Extract dimensions of the dataset
    n_seq, seq_len, n_features = torch.stack(dataset).shape


    # Return the dataset, sequence length, and number of features
    return dataset, seq_len, n_features
# Create the training dataset from train_df
train_dataset, seq_len, n_features = create_dataset(train_df)


# Create the validation dataset from val_df
val_dataset, _, _ = create_dataset(val_df)


# Create the test dataset for normal cases from test_df
test_normal_dataset, _, _ = create_dataset(test_df)


# Create the test dataset for anomalous cases from anomaly_df
test_anomaly_dataset, _, _ = create_dataset(anomaly_df)
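
As a quick sanity check (an illustrative sketch; the exact counts depend on the random split above), you can print the dimensions returned by create_dataset and the size of each split:

# Sanity check on the created datasets (illustrative)
print(f'Sequence length: {seq_len}, features per step: {n_features}')
print(f'Train / val / test (normal) / test (anomaly): '
      f'{len(train_dataset)} / {len(val_dataset)} / '
      f'{len(test_normal_dataset)} / {len(test_anomaly_dataset)}')
print(f'Single sequence tensor shape: {train_dataset[0].shape}')  # torch.Size([140, 1])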

Implementation of LSTM-Based Autoencoder for ECG Anomaly Detection

class Encoder(nn.Module):


  def __init__(self, seq_len, n_features, embedding_dim=64):
    super(Encoder, self).__init__()


    self.seq_len, self.n_features = seq_len, n_features
    self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dim


    self.rnn1 = nn.LSTM(
      input_size=n_features,
      hidden_size=self.hidden_dim,
      num_layers=1,
      batch_first=True
    )


    self.rnn2 = nn.LSTM(
      input_size=self.hidden_dim,
      hidden_size=embedding_dim,
      num_layers=1,
      batch_first=True
    )


  def forward(self, x):
    # Add a batch dimension: (seq_len, 1) -> (1, seq_len, n_features)
    x = x.reshape((1, self.seq_len, self.n_features))


    x, (_, _) = self.rnn1(x)
    x, (hidden_n, _) = self.rnn2(x)


    # The final hidden state of rnn2 serves as the embedding: (1, embedding_dim)
    return hidden_n.reshape((self.n_features, self.embedding_dim))
class Decoder(nn.Module):


  def __init__(self, seq_len, input_dim=64, n_features=1):
    super(Decoder, self).__init__()


    self.seq_len, self.input_dim = seq_len, input_dim
    self.hidden_dim, self.n_features = 2 * input_dim, n_features


    self.rnn1 = nn.LSTM(
      input_size=input_dim,
      hidden_size=input_dim,
      num_layers=1,
      batch_first=True
    )


    self.rnn2 = nn.LSTM(
      input_size=input_dim,
      hidden_size=self.hidden_dim,
      num_layers=1,
      batch_first=True
    )


    self.output_layer = nn.Linear(self.hidden_dim, n_features)


  def forward(self, x):
    # Repeat the embedding once per time step: (1, embedding_dim) -> (seq_len, embedding_dim)
    x = x.repeat(self.seq_len, self.n_features)
    # Add a batch dimension: (1, seq_len, input_dim)
    x = x.reshape((self.n_features, self.seq_len, self.input_dim))


    x, (hidden_n, cell_n) = self.rnn1(x)
    x, (hidden_n, cell_n) = self.rnn2(x)
    x = x.reshape((self.seq_len, self.hidden_dim))


    # Map each time step back to the original feature space
    return self.output_layer(x)
class Autoencoder(nn.Module):


  def __init__(self, seq_len, n_features, embedding_dim=64):
    super(Autoencoder, self).__init__()


    self.encoder = Encoder(seq_len, n_features, embedding_dim).to(device)
    self.decoder = Decoder(seq_len, embedding_dim, n_features).to(device)


  def forward(self, x):
    x = self.encoder(x)
    x = self.decoder(x)


    return x
# 'device' must exist before Autoencoder is instantiated, because
# Encoder and Decoder call .to(device) in their constructors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Autoencoder(seq_len, n_features, 128)
model = model.to(device)
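
Before training, a one-off forward pass (a quick sketch using the first training sequence) confirms that the autoencoder maps an input back to a tensor of the same shape:

# Illustrative shape check: the reconstruction should match the input shape
with torch.no_grad():
    sample = train_dataset[0].to(device)  # shape (140, 1)
    reconstruction = model(sample)        # shape (140, 1)
print(sample.shape, reconstruction.shape)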

Training and Visualization of ECG Autoencoder Model

def plot_input_reconstruction(model, dataset, epoch):
    model = model.eval()


    plt.figure(figsize=(10, 5))


    # Take the first sequence from the dataset
    seq_true = dataset[0].to(device)


    with torch.no_grad():
        # Run the forward pass without gradient tracking so the result
        # can be converted to a NumPy array below
        seq_pred = model(seq_true)


        # Squeeze the sequences to ensure they are 1-dimensional
        input_sequence = seq_true.squeeze().cpu().numpy()
        reconstruction_sequence = seq_pred.squeeze().cpu().numpy()


    # Check the shape after squeezing
    if input_sequence.ndim != 1 or reconstruction_sequence.ndim != 1:
        raise ValueError("Input and reconstruction sequences must be 1-dimensional after squeezing.")


    # Plotting the sequences
    plt.plot(input_sequence, label='Input Sequence', color='black')
    plt.plot(reconstruction_sequence, label='Reconstruction Sequence', color='red')
    plt.fill_between(range(len(input_sequence)), input_sequence, reconstruction_sequence, color='gray', alpha=0.5)


    plt.title(f'Input vs Reconstruction - Epoch {epoch}')
    plt.legend()
    plt.show()







def train_model(model, train_dataset, val_dataset, n_epochs, save_path):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    criterion = torch.nn.L1Loss(reduction='sum').to(device)
    history = {'train': [], 'val': []}


    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')


    for epoch in range(1, n_epochs + 1):
        model.train()


        train_losses = []
        for seq_true in train_dataset:
            optimizer.zero_grad()


            seq_true = seq_true.to(device)
            seq_pred = model(seq_true)


            loss = criterion(seq_pred, seq_true)


            loss.backward()
            optimizer.step()


            train_losses.append(loss.item())


        val_losses = []
        model.eval()
        with torch.no_grad():
            for seq_true in val_dataset:
                seq_true = seq_true.to(device)
                seq_pred = model(seq_true)


                loss = criterion(seq_pred, seq_true)
                val_losses.append(loss.item())


        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)


        history['train'].append(train_loss)
        history['val'].append(val_loss)


        if val_loss < best_loss:
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            # Save the best model weights
            print("Saving best model")
            torch.save(model.state_dict(), save_path)


        print(f'Epoch {epoch}: train loss {train_loss:.4f}, val loss {val_loss:.4f}')


        if epoch == 1 or epoch % 5 == 0:
            plot_input_reconstruction(model, val_dataset, epoch)


    # Load the best model weights before returning
    model.load_state_dict(best_model_wts)
    return model.eval(), history
save_path = 'best_model.pth'  # Replace with your actual path
model, history = train_model(model, train_dataset, val_dataset, 100, save_path)

ax = plt.figure().gca()


ax.plot(history['train'], label='Train Loss', color='black')
ax.plot(history['val'], label='Val Loss', color='red')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.title('Loss over training epochs')
plt.show()

ECG Anomaly Detection Model Evaluation and Visualization

model = Autoencoder(seq_len, n_features, 128)


model.load_state_dict(torch.load('best_model.pth', map_location=device))


model = model.to(device)
model.eval()
Autoencoder(
  (encoder): Encoder(
    (rnn1): LSTM(1, 256, batch_first=True)
    (rnn2): LSTM(256, 128, batch_first=True)
  )
  (decoder): Decoder(
    (rnn1): LSTM(128, 128, batch_first=True)
    (rnn2): LSTM(128, 256, batch_first=True)
    (output_layer): Linear(in_features=256, out_features=1, bias=True)
  )
)
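
As a rough size check (a sketch; the count follows from the layer shapes printed above), the number of trainable parameters can be read off the model directly:

# Count trainable parameters (illustrative)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable parameters: {n_params:,}')
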
def predict(model, dataset):
  predictions, losses = [], []
  criterion = nn.L1Loss(reduction='sum').to(device)
  with torch.no_grad():
    model = model.eval()
    for seq_true in dataset:
      seq_true = seq_true.to(device)
      seq_pred = model(seq_true)


      loss = criterion(seq_pred, seq_true)


      predictions.append(seq_pred.cpu().numpy().flatten())
      losses.append(loss.item())
  return predictions, losses
_, losses = predict(model, train_dataset)


# Visualising the distribution of reconstruction losses on the training set
# (sns.distplot is deprecated in recent seaborn; histplot is its replacement)
sns.histplot(losses, bins=50, kde=True, label='Train', color='black')


# Choose a loss threshold above which a heartbeat is flagged as anomalous
THRESHOLD = 25

predictions, pred_losses = predict(model, test_normal_dataset)
sns.histplot(pred_losses, bins=50, kde=True, color='black')

correct = sum(l <= THRESHOLD for l in pred_losses)
print(f'Correct normal predictions: {correct}/{len(test_normal_dataset)}')
Correct normal predictions: 141/145
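
The threshold of 25 is read off the histograms by eye. A data-driven alternative (a sketch, not part of the original pipeline) is to set it at a high percentile of the training losses:

# Hypothetical alternative: 95th percentile of training reconstruction losses
threshold_95 = np.percentile(losses, 95)
print(f'95th-percentile threshold: {threshold_95:.2f}')
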
anomaly_dataset = test_anomaly_dataset[:len(test_normal_dataset)]
predictions, pred_losses = predict(model, anomaly_dataset)
sns.histplot(pred_losses, bins=50, kde=True, color='red')

correct = sum(l > THRESHOLD for l in pred_losses)
print(f'Correct anomaly predictions: {correct}/{len(anomaly_dataset)}')

Correct anomaly predictions: 145/145
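
From these two runs you can also derive standard classification metrics. The sketch below treats anomalies as the positive class; the variable names (normal_losses, anomaly_losses, y_true, y_pred) exist only in this example:

# Illustrative evaluation over both test sets (1 = anomaly)
_, normal_losses = predict(model, test_normal_dataset)
_, anomaly_losses = predict(model, anomaly_dataset)

y_true = [0] * len(normal_losses) + [1] * len(anomaly_losses)
y_pred = [int(l > THRESHOLD) for l in normal_losses + anomaly_losses]

tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
precision = tp / (tp + fp) if tp + fp else 0.0
recall = tp / (tp + fn) if tp + fn else 0.0
print(f'Accuracy {accuracy:.3f}, precision {precision:.3f}, recall {recall:.3f}')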

def plot_prediction(data, model, title, ax):
  predictions, pred_losses = predict(model, [data])


  ax.plot(data, label='true',color='black')
  ax.plot(predictions[0], label='reconstructed',color='red')
  ax.set_title(f'{title} (loss: {np.around(pred_losses[0], 2)})')
  ax.legend()
fig, axs = plt.subplots(
  nrows=2,
  ncols=4,
  sharey=True,
  sharex=True,
  figsize=(22, 8)
)


for i, data in enumerate(test_normal_dataset[:4]):
  plot_prediction(data, model, title='Normal', ax=axs[0, i])


for i, data in enumerate(test_anomaly_dataset[:4]):
  plot_prediction(data, model, title='Anomaly', ax=axs[1, i])


fig.tight_layout();
