可解释机器学习之SHAP方法

以Breast cancer wisconsin (diagnostic) dataset数据集为例。

# Built-in libraries
import math
import numpy    as np
import pandas   as pd




# Visualization libraries
import matplotlib.pyplot as plt
import seaborn           as sns


# Sklearn libraries
#
from sklearn                 import metrics
from sklearn                 import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.datasets        import load_breast_cancer

Parameters

test_size = 0.1

Breast cancer wisconsin (diagnostic) dataset

Data Set Characteristics:

:Number of Instances: 569

:Number of Attributes: 30 numeric, predictive attributes and the class

:Attribute Information:
    - radius (mean of distances from center to points on the perimeter)
    - texture (standard deviation of gray-scale values)
    - perimeter
    - area
    - smoothness (local variation in radius lengths)
    - compactness (perimeter^2 / area - 1.0)
    - concavity (severity of concave portions of the contour)
    - concave points (number of concave portions of the contour)
    - symmetry 
    - fractal dimension ("coastline approximation" - 1)

    The mean, standard error, and "worst" or largest (mean of the three
    largest values) of these features were computed for each image,
    resulting in 30 features.  For instance, field 3 is Mean Radius, field
    13 is Radius SE, field 23 is Worst Radius.

    - class:
            - WDBC-Malignant
            - WDBC-Benign

:Summary Statistics:

:Missing Attribute Values: None

:Class Distribution: 212 - Malignant, 357 - Benign
# Load Breast Cancer dataset
data = load_breast_cancer() 


# Create DataFrame
#
df   = pd.DataFrame(data.data, columns=data.feature_names)
# Add a target column
#
df['class'] = data.target




# Show DataFrame
df.head(3)

#Pre-processing data
fig = plt.figure(figsize=(8
,3
))ax  = sns.countplot(df['class'], order = df['class'].value_counts().index)


#Create annotate
for i in ax.patches:
    ax.text(x        = i.get_x() + i.get_width()/2, 
            y        = i.get_height()/7, 
            s        = f"{np.round(i.get_height()/len(df)*100)}%", 
            ha       = 'center', 
            size     = 15, 
            weight   = 'bold', 
            rotation = 90, 
            color    = 'white');
    


plt.title("Class variable", size=12, weight='bold');

Training/Testing sets

X = df.iloc[:,:-1]
Y = df.iloc[:, -1]
trainX, testX, trainY, testY = train_test_split(X, Y, test_size=test_size, random_state=42)

Model development

Setup ML model

from sklearn.ensemble import GradientBoostingClassifier




# XGBoost model
#
model = GradientBoostingClassifier( random_state  = 42 )

Training ML model

model.fit(trainX, trainY);
model.fit(trainX, trainY);

Get Predictions

# Calculate prediction
#
pred = model.predict( testX )




# Performance accuracy
#
accuracy = metrics.accuracy_score(testY, pred)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
Accuracy: 96.49%

SHAP

%%capture
! pip install shap

import shap
# Generate the Tree explainer and SHAP values
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
expected_value = explainer.expected_value

Explainability/Visualizations

Summary dot plot

# Generate summary dot plot
#
shap.summary_plot(shap_values   = shap_values, 
                  features      = X, 
                  feature_names = df.columns[:-1], 
                  title         = 'SHAP summary plot')

Summary bar plot

# Generate summary bar plot 
#
shap.summary_plot(shap_values   = shap_values, 
                  features      = X, 
                  feature_names = df.columns[:-1], 
                  plot_type     = "bar")

Dependence plot

# Generate dependence plot
#
shap.dependence_plot(ind               = "worst concave points", 
                     shap_values       = shap_values, 
                     features          = X, 
                     feature_names     = df.columns[:-1],
                     interaction_index = "mean concave points")

Multiple dependence plots

# Generate multiple dependence plots
for name in df.columns[:-1]:
     shap.dependence_plot(ind               = name, 
                          shap_values       = shap_values, 
                          features          = X, 
                          feature_names     = df.columns[:-1],)

工学博士,担任《Mechanical System and Signal Processing》《中国电机工程学报》《控制与决策》等期刊审稿专家,擅长领域:现代信号处理,机器学习,深度学习,数字孪生,时间序列分析,设备缺陷检测、设备异常检测、设备智能故障诊断与健康管理PHM等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/722160.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

经典游戏案例:unity官方推荐3d跑酷

学习目标&#xff1a;实现跑酷核心算法 游戏画面 项目结构目录 部分核心代码 using System; using System.Collections; using System.Collections.Generic; using UnityEngine; /// <summary> /// 游戏管理器是一个状态机&#xff0c;根据当前的游戏状态&#xff0c;它…

旅游管理平台系统

摘要 如今许多地区的风景已经随着网络技术的不断发展和进步而映入人们的眼帘&#xff0c;旅游已经成为一种大众化的休闲方式。而青海海西州风光旖旎&#xff0c;民族文化独特&#xff0c;更是吸引了众多游客纷至沓来。海西州地域广阔、人烟稀少、是一个经济发展缓慢的地方&…

uni微信小程序使用lottie

在uni插件市场找到 lottie-uni https://ext.dcloud.net.cn/plugin?id1044按照文档要求安装 HBuilderX 引入 下载或导入示例获取插件 import lottie from /common/lottie-miniprogram.jsindex.vue <template><uni-popupref"popup"type"center"ba…

汽车IVI中控开发入门及进阶(二十九):i.MX6

前言: i.MX 6双/6Quad处理器集成多媒体应用处理器,是不断增长的多媒体产品系列的一部分,提供高性能处理,并针对最低功耗进行了优化。 i.MX 6Dual/6Quad处理器采用先进的quad-ArmCortex-A9内核,运行速度高达800 MHz,包括2D和3D图形处理器、1080p视频处理和集成电源管理。…

OPNsense 24.1 - 基于 FreeBSD 的开源防火墙和路由平台

OPNsense 24.1 - 基于 FreeBSD 的开源防火墙和路由平台 请访问原文链接&#xff1a;https://sysin.org/blog/opnsense/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 关于 OPNsense OPNsense 是一个开源、易于使用且易于构建…

经典游戏案例:飞机大战

学习目标&#xff1a;实现飞机射击核心功能 游戏画面 项目结构目录 部分核心代码 using UnityEngine; using System.Collections; using UnityEngine.EventSystems; public class playermoveA :MonoBehaviour,IPointerUpHandler,IPointerDownHandler,IDragHandler{ //public …

org.eclipse.milo opcua 库查看记录

1 Reference连接 在OPC UA Server中&#xff0c;所有Node之间都是使用Reference进行连接的。 读取时指定HierarchicalReferences就可以读取HierarchicalReferences及以下所有类型的节点。 2 nodeId读取 browse 默认读取了Method、Object、Variable类型节点&#xff0c;Refer…

Python爬虫实战案例之——MySql数据入库

Hello大家好&#xff0c;我是你们的南枫学长&#xff0c;咱们今天来学——爬虫之MySql数据入库。 话不多说&#xff0c;导入咱们的老朋友&#xff1a; Pymysql就是我们Python里面的mysql库&#xff0c;主要功能就是用来连接MySql数据库&#xff0c;那么下载还是一样的操作去进…

ClickHouse安装与下载22.3.2.2

ClickHouse安装与下载 目录 1. ClickHouse简介 1.1 ClickHouse优点&#xff1a; 1.2 ClickHouse缺点&#xff1a; 1.3 ClickHouse引擎&#xff1a; 1.3.1 数据库引擎 1.3.2 表引擎 2. ClickHouse下载安装 2.1 ClickHouse下载安装 2.2 ClickHouse使用 1. ClickHouse简…

从ITIL,CMMI到DevOps的实践与思考

点击进入IT管理资料库 在信息技术迅猛发展的今天&#xff0c;企业对IT运维和管理的要求越来越高。从最早的ITIL&#xff0c;到后来的CMMI&#xff0c;再到现在风靡全球的DevOps&#xff0c;每一个管理框架的出现都代表着一种新的思维和实践模式。ITIL帮助企业建立起系统的IT服…

代码随想录算法训练营第四十二天|1049. 最后一块石头的重量 II , 494. 目标和 , 474.一和零

1049. 最后一块石头的重量 II - 力扣&#xff08;LeetCode&#xff09; class Solution {public int lastStoneWeightII(int[] stones) {if(stones.length 0){return 0;}if(stones.length 1){return stones[0];}int sum 0;for(int i0;i<stones.length;i){sum stones[i];…

leetcode118 杨辉三角

给定一个非负整数 numRows&#xff0c;生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]]示例 2: 输入: numRows 1 输出: [[1]] public List…

【Ubuntu】--- 创建用户 删除用户 及其他用户操作大全 持续更新中

在编程的艺术世界里&#xff0c;代码和灵感需要寻找到最佳的交融点&#xff0c;才能打造出令人为之惊叹的作品。而在这座秋知叶i博客的殿堂里&#xff0c;我们将共同追寻这种完美结合&#xff0c;为未来的世界留下属于我们的独特印记。 【Ubuntu】--- 创建用户 删除用户 及其他…

【C++】#20,#21

#20类和对象 #include <iostream>using namespace std;class Box{public: //公有 double length; //ctrle复制本行 double width;double height;void getVolume(){ //方法带&#xff08;&#xff09; cout<<"盒子体积为&#xff1a;"<<le…

threejs教程:绘制3D地图(广东省区划图)

一、效果展示&#xff1a; 二、开发准备 Three.js中文文档&#xff1a;Three.js中文网 Three.js文本渲染插件&#xff1a;Troika 3D Text - Troika JS 行政区划边界数据查询&#xff08;阿里云数据可视化平台&#xff09;&#xff1a;DataV.GeoAtlas地理小工具系列 1. 在项目…

STM32学习 时钟树

在单片机中&#xff0c;时钟的概念非常重要&#xff0c;这次记录一下时钟树相关的知识。 STM32的时钟树是由多个时钟源和时钟分频组成的&#xff0c;为STM32芯片提供各种时钟信号。也就是说&#xff0c;在使用STM32的时候&#xff0c;所有的频率和时钟都是通过时钟树产生的。 …

Maven添加reactor依赖失败

目录 情况说明 解决过程 情况说明 起初是自己在学spring boot3&#xff0c;结果到了reactor这一部分的时候&#xff0c;在项目的pom.xml文件中添加下列依赖报错&#xff1a; <dependencyManagement><dependencies><dependency><groupId>io.projectr…

github配置可拉取项目到本地

首先配置用户名和邮箱&#xff1a; git config --global user.name 自己的名字git config --global user.email 自己的邮箱配置完之后检查一下&#xff1a; git config --global user.namegit config --global user.email如果提示的是自己配置好的名字和邮箱就Ok 然后拉取githu…

NLP入门——基于梯度下降法分类的应用

问题分析 我们前面研究的都是基于统计的方法&#xff0c;通过不同的统计方法得到不同的准确率&#xff0c;通过改善统计的方式来提高准确率。现在我们要研究基于数学的方式来预测准确率。 假设我们有一个分词 s_{class,word}&#xff0c;class是该对象的类别&#xff0c;word…

数据库大作业——音乐平台数据库管理系统

W...Y的主页&#x1f60a; 代码仓库分享&#x1f495; 《数据库系统》课程设计 &#xff1a;流行音乐管理平台数据库系统&#xff08;本数据库大作业使用软件sql server、dreamweaver、power designer&#xff09; 目录 系统需求设计 数据库概念结构设计 实体分析 属性分…