一维卷积神经网络的特征可视化

随着以深度学习为代表的人工智能技术的不断发展,许多具有重要意义的深度学习模型和算法被开发出来,应用于计算机视觉、自然语言处理、语音处理、生物医疗、金融应用等众多行业领域。深度学习先进的数据挖掘、训练和分析能力来源于深度神经网络的海量模型参数以及高度非线性。也正因为深度学习算法的高度复杂性,许多模型往往难以解释其内部工作原理,这导致这些模型被称为缺乏可解释性的“黑箱模型”。

随着AI应用渗透到各行各业,AI的科技伦理受到广泛的关注。而科技伦理的一个核心议题就是可解释人工智能XAI。从社会科学角度,可解释性是指人对决策原因的理解程度,可解释性越高,人就越能理解为什么做出这样的决策。对应于AI领域,可解释性是指能够在一定程度上揭示AI模型内部工作机制和对模型结果的进行解释,帮助用户理解模型是如何做出预测或决策的。

因此,本文简单地对一维卷积神经网络的特征进行可视化,运行环境为Python,研究对象为心电信号。

首先导入相关库

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import signal_screen
import signal_screen_tools

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv1D, MaxPool1D, Flatten, BatchNormalization, Input
from tensorflow.keras.callbacks import ModelCheckpoint

数据导入及处理

# load data
data_train = pd.read_csv("mitbih_train.csv", sep=",", header=None).to_numpy()
data_test = pd.read_csv("mitbih_test.csv", sep=",", header=None).to_numpy()

# get X and y
X_train, y_train = data_train[:, :data_train.shape[1]-2], data_train[:, -1]
X_test, y_test = data_test[:, :data_test.shape[1]-2], data_test[:, -1]

# number of categories
num_of_categories = np.unique(y_train).shape[0]

del data_train, data_test

#indexing examples to show visualisations
examples_to_visualise = [np.where(y_test == i)[0][0] for i in range(5)]
titles = [ "nonectopic", "supraventricular ectopic beat", "ventricular ectopic beat", "fusion beat", "unknown"]

# creation of tensors
X_train = np.expand_dims(tf.convert_to_tensor(X_train), axis=2)
X_test = np.expand_dims(tf.convert_to_tensor(X_test), axis=2)

# one-hot encoding for 5 categories
y_train = tf.one_hot(y_train, num_of_categories)
y_test = tf.one_hot(y_test, num_of_categories)

建立模型并进行训练

# basic model
model = Sequential([
    Input(shape=[X_train.shape[1], 1]),
    Conv1D(filters=16, kernel_size=3, activation="relu"),
    BatchNormalization(),
    MaxPool1D(),
    Conv1D(filters=32, kernel_size=3, activation="relu"),
    BatchNormalization(),
    Conv1D(filters=64, kernel_size=3, activation="relu"),
    BatchNormalization(),
    Flatten(),
    Dense(20, activation="relu"),
    Dense(num_of_categories, activation="softmax")
]
)

# train process

model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

checkPoint = ModelCheckpoint(filepath="model.h5", save_weights_only=False, monitor='val_accuracy',
                            mode='max', save_best_only=True)

model.fit(x=np.expand_dims(X_train, axis=2), y=y_train,
          batch_size=128, epochs=10, validation_data=(np.expand_dims(X_test, axis=2), y_test),
          callbacks=[checkPoint])

model = tf.keras.models.load_model("model.h5")
loss, acc = model.evaluate(np.expand_dims(X_test, axis=2), y_test)

采用Occlusion Sensitivity方法进行可视化,相关的参考文献较多。

fig, axs = plt.subplots(nrows=5, ncols=1)
fig.suptitle("Occlusion sensitivity")
fig.tight_layout()
fig.set_size_inches(10, 10)
axs = axs.ravel()

for c, row, ax, title in zip(range(5), examples_to_visualise, axs, titles):
    sensitivity, _ = signal_screen.calculate_occlusion_sensitivity(model=model,
                                                         data=np.expand_dims(X_test[row, :], axis=(0, 2)),
                                                         c=c,
                                                         number_of_zeros=[15])

    # create gradient plot
    signal_screen_tools.plot_with_gradient(ax=ax, y=X_test[row, :].ravel(), gradient=sensitivity[0], title=title)
    ax.set_xlabel("Samples[-]")
    ax.set_ylabel("ECG [-]")

plt.show()

采用Saliency map方法进行可视化。

采用Grad-CAM方法进行可视化。

工学博士,担任《Mechanical System and Signal Processing》审稿专家,担任《中国电机工程学报》优秀审稿专家,《控制与决策》,《系统工程与电子技术》,《电力系统保护与控制》,《宇航学报》等EI期刊审稿专家。

擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/510236.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于ssm的三省学堂-学习辅助系统(java项目+文档+源码)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于ssm的三省学堂-学习辅助系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 三省学堂-学习辅助系统的…

公司只有一个测试,要怎么继续呆下去?

在面试的时候,面试官可能会问:小公司、小团队,岗位就你一个人,怎么做 ? 或者已经有的小伙伴已经在公司中面临只有一个测试的处境,这个时候我们应该怎么处理呢? 一 原因分析 公司只有一个测试人…

Golang 哈希表底层实现原理

1、本文讨论Golang的哈希表 Golang哈希表的实现,底层数据结构是数组单链表,链表节点由8个key、value和键的高八位组成的。为了方便理解,先简单看一个图快速理解。 我们来看一下Golang哈希表的结构体定义 简单介绍一下结构体中几个关键的…

基于SpringBoot+Vue在线考试系统设计与实现+搭建视频

介绍 该在线考试系统共包含三种角色,分别是:学生、老师和管理员,不同角色对系统的功能需求也不同。 具体功能如下: 1)学生 考生注册、考生登录、在线考试、我的成绩、我的题库、修改个人资料等功能。 2&#xff0…

网络攻防中之url跳转过程分析和使用欺骗方法生成自己的恶意链接过程,以及点击劫持和绕过验证的几种方式

网络攻防中之url跳转过程分析和使用欺骗方法生成自己的恶意链接过程,以及点击劫持和绕过验证的几种方式。 URL跳转过程分析 URL跳转是Web应用中常见的一种行为,它通常通过HTTP重定向来实现。在网络攻防中,分析URL跳转过程对于理解应用的行为和识别潜在的安全漏洞至关重要。 …

吴恩达2022机器学习专项课程(一) 4.6 运行梯度下降第一周课程实验:线性回归的梯度下降算法

问题预览/关键词 更新梯度下降对模型拟合,等高线图,3d空间图的变化。什么是批量梯度下降。实验目标计算梯度运行梯度下降梯度下降迭代次数和成本函数的关系可视化模型预测在等高线图上的梯度下降学习率过大报错问题 笔记 1.模型拟合,等高线…

怎么修改图片的创建日期和修改日期?

怎么修改图片的创建日期和修改日期?大家都应该知道,电脑上的任何一种文件都有创建日期和修改日期,不管word、excel、ppt还是图片,这两个时间属性是都必须具备的。在数字时代,我们经常使用照片来记录珍贵的时刻和重要的…

VISA、masterCard卡进行USDT消费,无需实名,0年费,0月费

开卡流程 1、点击获取卡 2、注册之后点击“流量钱包->点击点此充值” 3、选择积分充值点击确认即可 在返回到首页点击申请卡,选择534786与556150都可以,选择钱包支付即可 点击获取卡片

【Java代码审计】SpEL表达式注入篇

【Java代码审计】SpEL表达式注入篇 1.SpEL 介绍2.SpEL漏洞概述3.SpEL漏洞演示4.SpEL漏洞修复 1.SpEL 介绍 Spring 表达式语言是一种功能强大的表达式语言,用于在运行时查询和操作对象视图,语法上类似于 Unified EL,但提供了更多的特性&#…

Flutter应用在苹果商店上架前的准备工作与注意事项

引言 🚀 Flutter作为一种跨平台的移动应用程序开发框架,为开发者提供了便利,使他们能够通过单一的代码库构建出高性能、高保真度的应用程序,同时支持Android和iOS两个平台。然而,完成Flutter应用程序的开发只是第一步…

【Servlet】继承关系以及service方法

文章目录 一、继承关系二、相关方法 一、继承关系 Servlet接口下有一个GenericServlet抽象类。在GenericServlet下有一个子类HttpServlet,它是基于http协议。 继承关系 javax.servlet.Servlet接口​ javax.GenericServlet抽象类​ javax.servlet.http.HttpServ…

毕马威:《智慧之眼:开启汽车感知新时代》

在全球科技飞速发展和产业革新的大潮中,汽车产业正在以前所未有的速度向网联化、智能化的方向转型。汽车传感器作为智能联网汽车发展的关键环节之一,扮演着举足轻重的角色。 毕马威一直关注汽车产业的变化与发展,为了更好地为汽车行业赋能&a…

python 哔哩哔哩视频去水印

使用python 去除视频中的水印 1. 需要安装的包 pip install moviepy pip install numpy pip install opencv_python pip install tqdm 2. 代码 import cv2 import numpy as np import glob from moviepy.editor import VideoFileClip import os from tqdm import tqdm# 判…

第九届全国大学生GIS应用技能大赛试题答案及数据下载(下午)

一、案例背景 我们现在是江苏省城市研究科研项目组的一员,我们分配到了以下任务: **任务一:**创建三甲医院 20 分钟、45 分钟服务区,并计算每一个地级市/县级市拥有的三甲医院 20 钟、45 分钟服务区占全市面积百分比。 **任务二&a…

ERROR: No matching distribution found for sklearn.cross_validation

问题 ERROR: Could not find a version that satisfies the requirement sklearn.cross_validation (from versions: none) ERROR: No matching distribution found for sklearn.cross_validation 错误:找不到满足sklearn要求的版本。Cross_validation (from versions: none)…

Nginx开发实战三:替换请求资源中的固定数据

文章目录 1.效果预览2.下载Nginx解压并初始化3.字符串替换模块安装4.修改nginx配置文件并重启 1.效果预览 页面初始效果 页面替换后效果 说明:页面是内网的一个地址,我们通过nginx可以很便捷的将其改为外网访问,但是在外网访问这个地址后&#xff0c…

算法之美:二叉堆原理剖析及堆应用案例讲解及实现

什么是堆 堆(Heap)是计算机科学中一类特殊的数据结构,通常是一个可以被看做一棵完全二叉树的数组对象。 完全二叉树 只有最下面两层节点的度可以小于2,并且最下层的叶节点集中在靠左连续的边界,只允许最后一层有空缺结点且空缺在右边&#x…

工艺品wordpress外贸主题

工艺品wordpress外贸主题 简约大气的wordpress外贸主题,适合做工艺品进出品外贸的公司官网使用。 https://www.jianzhanpress.com/?p5377

记一次 pdfplumber 内存泄漏导致的服务器宕机

有一个项目需求,要在每天凌晨5点的时候执行一个任务,获取一系列的PDF文件并解析。 后端是Django框架,定时任务用Celery来实现的。 本地跑没什么问题,但是一放到服务器上跑就会宕机,而且是毫无征兆的宕机,…

前端学习<二>CSS基础——17-CSS3的常见边框汇总

CSS3 常见边框汇总 <!DOCTYPE html><html lang"en"><head><meta charset"UTF-8"><title>CSS3 边框</title><style>body, ul, li, dl, dt, dd, h1, h2, h3, h4, h5 {margin: 0;padding: 0;}​body {background-c…