基于集成学习的用户流失预测并利用shap进行特征解释

基于集成学习的用户流失预测并利用shap进行特征解释

小P:小H,如果我只想尽可能的提高准确率,有什么好的办法吗?

小H:优化数据、调参侠、集成学习都可以啊

小P:什么是集成学习啊,听起来就很厉害的样子

小H:集成学习就类似于【三个臭皮匠顶个诸葛亮】,将一些基础模型组合起来使用,以期得到更好的结果

集成学习实战

数据准备

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
import warnings
warnings.filterwarnings('ignore')

from scipy import stats
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import SMOTE 
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, KFold
from sklearn.feature_selection import RFE 
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, ExtraTreesClassifier
import xgboost as xgb
from sklearn.metrics import accuracy_score, auc, confusion_matrix, f1_score, \
    precision_score, recall_score, roc_curve  # 导入指标库
import prettytable
import sweetviz as sv # 自动eda
import toad 
from sklearn.model_selection import StratifiedKFold, cross_val_score  # 导入交叉检验算法

# 绘图初始化
%matplotlib inline
pd.set_option('display.max_columns', None) # 显示所有列
sns.set(style="ticks")
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号

# 导入自定义模块
import sys
sys.path.append("/Users/heinrich/Desktop/Heinrich-blog/数据分析使用手册")
from keyIndicatorMapping import *

上述自定义模块keyIndicatorMapping如果有需要的同学可关注公众号HsuHeinrich,回复【数据挖掘-自定义函数】自动获取~

以下数据如果有需要的同学可关注公众号HsuHeinrich,回复【数据挖掘-集成学习】自动获取~

# 读取数据
raw_data = pd.read_csv('classification.csv')
raw_data.head()

image-20230206151936701

# 缺失值填充,SMOTE方法限制非空
raw_data=raw_data.fillna(raw_data.mean())
# 数据集分割
X = raw_data[raw_data.columns.drop('churn')]
y = raw_data['churn']
# 标准化
scaler = StandardScaler() 
scale_data = scaler.fit_transform(X)  
X = pd.DataFrame(scale_data, columns = X.columns)
# 划分训练测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.3, random_state=0)
# 过采样
model_smote = SMOTE(random_state=0)  # 建立SMOTE模型对象
X_train, y_train = model_smote.fit_resample(X_train, y_train) 

模型对比

%%time
# 初选分类模型
model_names = ['LR', 'SVC', 'RFC', 'XGBC']  # 不同模型的名称列表
model_lr = LogisticRegression(random_state=10) # 建立逻辑回归对象
model_svc = SVC(random_state=0, probability=True) # 建立支持向量机分类对象
model_rfc = RandomForestClassifier(random_state=10) # 建立随机森林分类对象
model_xgbc = xgb.XGBClassifier(use_label_encoder=False, eval_metric='auc', random_state=10) # 建立XGBC对象

# 模型拟合结果
model_list = [model_lr, model_svc, model_rfc, model_xgbc]  # 不同分类模型对象的集合
pre_y_list = [model.fit(X_train, y_train).predict(X_test) for model in model_list]  # 各个回归模型预测的y值列表
CPU times: user 2.49 s, sys: 125 ms, total: 2.62 s
Wall time: 843 ms
# 核心评估指标
metrics_dic = {
                'model_names':[],
                'auc':[],
                'ks':[],
                'accuracy':[],
                'precision':[],
                'recall':[],
                'f1':[]
            }
for model_name, model, pre_y in zip(model_names, model_list, pre_y_list):
    y_prob = model.predict_proba(X_test)  # 获得决策树的预测概率,返回各标签(即0,1)的概率
    fpr, tpr, thres = roc_curve(y_test, y_prob[:, 1])  # ROC y_score[:, 1]取标签为1的概率,这样画出来的roc曲线为正
    metrics_dic['model_names'].append(model_name)
    metrics_dic['auc'].append(auc(fpr, tpr))  # AUC
    metrics_dic['ks'].append(max(tpr - fpr)) # KS值
    metrics_dic['accuracy'].append(accuracy_score(y_test, pre_y))
    metrics_dic['precision'].append(precision_score(y_test, pre_y))
    metrics_dic['recall'].append(recall_score(y_test, pre_y))
    metrics_dic['f1'].append(f1_score(y_test, pre_y))
pd.DataFrame(metrics_dic)

image-20230206152007352

集成学习

%%time
# 建立组合评估器列表 均衡稳定性和准确性 这里只是演示,就将所有模型都纳入了
estimators = [('SVC', model_svc), ('RFC', model_rfc), ('XGBC', model_xgbc), ('LR', model_lr)]  
model_vot = VotingClassifier(estimators=estimators, voting='soft', weights=[1.1, 1.1, 0.9, 1.2],
                             n_jobs=-1)  # 建立组合评估模型
cv = StratifiedKFold(5)  # 设置交叉检验方法 分类算法常用交叉检验方法
cv_score = cross_val_score(model_vot, X_train, y_train, cv=cv, scoring='accuracy')  # 交叉检验
print('{:*^60}'.format('Cross val scores:'),'\n',cv_score) # 打印每次交叉检验得分
print('Mean scores is: %.2f' % cv_score.mean())  # 打印平均交叉检验得分
*********************Cross val scores:********************** 
 [0.73529412 0.7745098  0.85294118 0.85294118 0.87745098]
Mean scores is: 0.82
CPU times: user 2.38 s, sys: 432 ms, total: 2.81 s
Wall time: 5 s
# 模型训练
model_vot.fit(X_train, y_train)  # 模型训练
VotingClassifier(estimators=[('SVC', SVC(probability=True, random_state=0)),
                             ('RFC', RandomForestClassifier(random_state=10)),
                             ('XGBC',
                              XGBClassifier(base_score=0.5, booster='gbtree',
                                            colsample_bylevel=1,
                                            colsample_bynode=1,
                                            colsample_bytree=1,
                                            eval_metric='rmse', gamma=0,
                                            gpu_id=-1, importance_type='gain',
                                            interaction_constraints='',
                                            learning_rate=0.300000012,
                                            max...
                                            min_child_weight=1, missing=nan,
                                            monotone_constraints='()',
                                            n_estimators=100, n_jobs=8,
                                            num_parallel_tree=1,
                                            random_state=10, reg_alpha=0,
                                            reg_lambda=1, scale_pos_weight=1,
                                            subsample=1, tree_method='exact',
                                            use_label_encoder=False,
                                            validate_parameters=1,
                                            verbosity=None)),
                             ('LR', LogisticRegression(random_state=10))],
                 n_jobs=-1, voting='soft', weights=[1.1, 1.1, 0.9, 1.2])
model_confusion_metrics(model_vot, X_test, y_test, 'test')
model_core_metrics(model_vot, X_test, y_test, 'test')
confusion matrix for test
 +----------+--------------+--------------+
|          | prediction-0 | prediction-1 |
+----------+--------------+--------------+
| actual-0 |      53      |      31      |
| actual-1 |      37      |     179      |
+----------+--------------+--------------+
core metrics for test
 +-------+----------+-----------+--------+-------+-------+
|  auc  | accuracy | precision | recall |   f1  |   ks  |
+-------+----------+-----------+--------+-------+-------+
| 0.805 |  0.773   |   0.589   | 0.631  | 0.609 | 0.504 |
+-------+----------+-----------+--------+-------+-------+

可以看到集成学习的各项指标表现均优异,只有召回率低于LR

利用shap进行模型解释

shap作为一种经典的事后解释框架,可以对每一个样本中的每一个特征变量,计算出其重要性值,达到解释的效果。该值在shap中被专门称为Shapley Value。

该系列以应用为主,对于具体的理论只会简单的介绍它的用途和使用场景。这里的shap相关知识 可以参考黑盒模型事后归因解析:SHAP方法、SHAP知识点全汇总

学无止境,且学且珍惜~

# pip install shap
import shap   
# 初始化
shap.initjs()  
# 通过采样提高计算效率,但会导致准确率降低。表现在base_value与mean(model.predict_proba(X))存在差异,不建议K太小
# X_test_summary = shap.sample(X_test, 200)
# X_test_summary = shap.kmeans(X_test, 150)
explainer = shap.KernelExplainer(model_vot.predict_proba, X_test)
shap_values = explainer.shap_values(X_test, nsamples = 10)
Using 300 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
  • 单样本查看
# 单样本查看-1概率较高的样本 # 208
shap.force_plot(base_value=explainer.expected_value[1],
                shap_values=shap_values[1][208],
                features = X_test.iloc[208,:]
               )

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-AxXaXK0k-1679902455430)(null)]

  • base_value:所有样本预测值的均值,即base_value=model_vot.predict_proba(X_test)[:,1].mean()

    ⚠️注意:当进行采样或者kmean加速计算时,会损失一定准确度。即explainer带入的是X_test_summary

  • f(x):预测的实际值model_vot.predict_proba(X_test)[:,1]

  • data:样本特征值

  • shap_values:f(x)-base_value;shap值越大越红,越小越蓝

# 验证base_value
print('所有样本预测标签1的概率均值:',model_vot.predict_proba(X_test)[:,1].mean())
print('base_value:',explainer.expected_value[1])
所有样本预测标签1的概率均值: 0.3519852365700774
base_value: 0.35198523657007774

经验证,base_value计算逻辑正确

# 验证单一样本
i=208
fx=model_vot.predict_proba(X_test)[:,1][i]
da=X_test.iloc[i,:]
sv=fx-explainer.expected_value[1]
sv_val=shap_values[1][i].sum()
print('f(x):',fx)
print('shap_values:',sv,sv_val)
f(x): 0.9264517406651224
shap_values: 0.5744665040950446 0.5744665040950446

经验证,shap_values计算逻辑正确

  • 特征重要性
# 特征重要程度
shap.summary_plot(shap_values[1],X_test,max_display=10,plot_type="bar")

  • 蜂窝图体现特征重要性
# 特征与样本蜂窝图
shap.summary_plot(shap_values[1],X_test,max_display=10)

output_38_0

retention_days越大,蓝色的样本越多,表明较高的retention_days有助于缓减流失

  • 特征的shap值
# 单特征预测结果
shap.dependence_plot("retention_days", shap_values[1], X_test, interaction_index=None)

output_41_0

retention_days低的shape值较大,上面讲到shap越大越红,对于y起到提高作用。即retention_days与流失负相关

# 双特征交叉影响
shap.dependence_plot("retention_days", shap_values[1], X_test, interaction_index='level')

output_43_0

  • 在较低的retention_days(如-1.5),高level(level=1.0)的shepae值较高(红色点),在0.2附近
  • 在较高的retention_days(如1.5),高level(level=1.0)的shepae值较低(红色点),在-0.2附近

总结

集成学习能有效地提高模型的预测性能,但是使得模型内部结构更为复杂,无法直观理解。好在可以借助shap进行常见的特征重要性解释等。

共勉~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/5949.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SSM—【笔记】1.2 SpringMVC

SpringMVC:用于表现层开发,同Servlet功能等同,但比Servlet技术使用更加简便,可以用更少代码量完成开发 项目结构: 后端采用的是三层架构模式: 数据层:先学的JDBC技术,后用MyBatis框架取代 表…

ThreeJS-缩放、旋转(四)

代码&#xff1a; <template> <div id"three_div"> </div> </template> <script> import * as THREE from "three"; import {OrbitControls } from three/examples/jsm/controls/OrbitControls export default { name: &quo…

在华为做了三年软件测试被裁了,我该怎么办

近年来&#xff0c;随着经济环境的变化和企业战略的调整&#xff0c;员工被裁员的情况变得越来越普遍。无论是因为企业经营困难还是因为业务调整&#xff0c;员工们都可能面临被裁员的风险。如果你也遇到了这样的情况&#xff0c;那么你应该怎么办呢&#xff1f; 首先&#xf…

centos7 SystemV 开机自启动脚本配置方法 redis集群三主三从

centos7 SystemV 开机自启动脚本配置方法 redis集群三主三从1、安装redis集群2、编写redis启动脚本2.1、建立启动脚本2.2、复制多份redis启动脚本给集群使用2.3、添加可执行权限3、配置开机自启动1、安装redis集群 参考: redis三主三从集群安装 2、编写redis启动脚本 2.1、建…

RabbitMQ 07 发布订阅模式

发布订阅模式 发布订阅模式结构图&#xff1a; 比如信用卡还款日临近了&#xff0c;那么就会给手机、邮箱发送消息&#xff0c;提示需要去还款了&#xff0c;但是手机短信和邮件发送并不一定是同一个业务提供的&#xff0c;但是现在又希望能够都去执行&#xff0c;就可以用到发…

HTTP协议发展历程-HTTP2【协议篇】

HTTP2.0 HTTP2为了解决HTTP1.1中存在的问题。其中慢启动和TCP连接竞争是TCP本身导致的&#xff0c;在H2中依赖的还是TCP协议&#xff0c;不过思路换了一下。 HTTP/2 的思路就是一个域名只使用一个 TCP 长连接来传输数据&#xff0c;这样整个页面资源的下载过程只需要一次慢启动…

【Elastic (ELK) Stack 实战教程】04、ElasticSearch 集群进阶及优化

目录 一、ES 集群故障转移 1.1 什么是故障转移 1.2 模拟节点故障 1.2.1 重新选举 1.2.2 主分片调整 1.2.3 副本分片调整 二、ES 文档路由原理 2.1 文档的创建流程 2.2 文档的读取流程 2.3 文档批量创建的流程 2.4 文档批量读取的流程 ​三、ES扩展集群节点 3.1 …

【目标检测论文阅读笔记】Multi-scene small object detection with modified YOLOv4

Abstract. 小目标检测的应用存在于我们日常生活中的许多不同场景中&#xff0c;该课题也是目标检测与识别研究中最难的问题之一。因此&#xff0c;提高小目标检测精度不仅在理论上具有重要意义&#xff0c;在实践中也具有重要意义。然而&#xff0c;当前的检测相关算法在这项任…

Node.js学习笔记——Express.js

一、express介绍 express是一个基于Node.js平台的极简、灵活的WEB应用开发框架&#xff0c;官方网址&#xff1a;https://www.expressjs.com.cn/ 二、express使用 2.1express下载 express本身是一个npm包&#xff0c;所以可以通过npm安装。 npm init npm i express 2.2expr…

Java接口

目录 抽象类 抽象类的概述 如何使用抽象类 抽象类的使用 抽象特征 关于抽象需要注意的几个事情 接口(interface) 常量 如何实现接口 接口与接口多继承 接口的注意事项 抽象类 抽象类的概述 父类中的方法&#xff0c;被它的子类们重写&#xff0c;子类各自的实现都不…

《花雕学AI》02:人工智能挺麻利,十分钟就为我写了一篇长长的故事

ChatGPT最近火爆全网&#xff0c;上线短短两个多月&#xff0c;活跃用户就过亿了&#xff0c;刷新了历史最火应用记录&#xff0c;网上几乎每天也都是ChatGPT各种消息。国内用户由于无法直接访问ChatGPT&#xff0c;所以大部分用户都无缘体验。不过呢&#xff0c;前段时间微软正…

Vulnhub:DC-3靶机

kali&#xff1a;192.168.111.111 靶机&#xff1a;192.168.111.250 信息收集 端口扫描 nmap -A -v -sV -T5 -p- --scripthttp-enum 192.168.111.250 通过nmap得知目标CMS为Joomla 3.7.0 漏洞利用 搜索发现该版本存在sql注入 利用sqlmap获取目标后台用户密码 sqlmap -u &…

测试行业3年经验,面试想拿 17K,HR说你只值 8K,该如何回答或者反驳?

面试最尴尬的不是被拒绝&#xff0c;而是直接说你不值那个价格... 最近朋友在面试的时候&#xff0c;HR 突然来了句&#xff1a;你只值 7K。朋友后面和我说了这个事。我想如果是我处在这种情况下&#xff0c;自己并不能很好地回答或者反驳。不知道大家会怎么回答或者反驳&…

基于vivado(语言Verilog)的FPGA学习(4)——FPGA选择题总结(针对华为逻辑岗实习笔试)

基于vivado&#xff08;语言Verilog&#xff09;的FPGA学习&#xff08;4&#xff09;——FPGA选择题总结 文章目录基于vivado&#xff08;语言Verilog&#xff09;的FPGA学习&#xff08;4&#xff09;——FPGA选择题总结1. 消除险象2. 建立时间和保持时间3.ISE4.DMA5.仿真器6…

【Linux】-- 权限和Shell运行原理

目录 Shell的运行原理 用户切换 su - / su sudo 权限 chmod chown chgrp 八进制方法修改文件属性 目录权限 粘滞位 umask 自定义默认权限 Shell的运行原理 广义上&#xff0c;Linux发行版 Linux内核 外壳程序 Linux 从广义上来理解它是一个操作系统 而从狭义上…

关于Map类的使用小结

目录 1. 常用Map类和区别 2. HashMap工作原理 2.1 Put()执行过程 2.2 扩容机制 3. ConcurrentHashMap 3.1 工作原理 3.2 JDK7分段锁的优缺点 1. 常用Map类和区别 Map类包含&#xff1a;HashMap、HashTable、LinkedHashMap、TreeMap。 1) 从功能上区分。 HashMap&…

多线程进阶学习11------CountDownLatch、CyclicBarrier、Semaphore详解

CountDownLatch ①. CountDownLatch主要有两个方法,当一个或多个线程调用await方法时,这些线程会阻塞 ②. 其它线程调用countDown方法会将计数器减1(调用countDown方法的线程不会阻塞) ③. 计数器的值变为0时,因await方法阻塞的线程会被唤醒,继续执行 public static void m…

SpringBoot学习笔记上

文章目录1 SpringBoot1.1 SpringBoot介绍1.2 SpringBoot创建的三种方式1.3SpringBootApplication注解1.4 SpringBoot的配置文件1.5多环境配置1.6 使用jsp1.7 ComnandLineRunner 接口 &#xff0c; ApplcationRunner接口2 Web组件2.1 拦截器2.2 Servlet2.3 过滤器Filter2.4 字符…

gpt3官网中文版-人工智能软件chat gpt安装

GPT-3&#xff08;Generative Pre-trained Transformer 3&#xff09;是一种自然语言处理模型&#xff0c;由OpenAI研发而成。它是GPT系列模型的第三代&#xff0c;也是目前最大、最强大的自然语言处理模型之一&#xff0c;集成了1750亿个参数&#xff0c;具有广泛的使用场景&a…

Flutter Row 实例 —— 新手礼包

大家好&#xff0c;我是 17。 本文在 3.31 日全站综合热榜第一。 新手礼包一共 3 篇文章&#xff0c;每篇都是描述尽量详细&#xff0c;实例讲解&#xff0c;包会&#xff01; Flutter Row 实例 —— 新手礼包Flutter TextField UI 实例 —— 新手礼包Flutter TextField 交…