利用SHAP算法解释BERT模型的输出

1 何为SHAP?

传统的 feature importance 只告诉哪个特征重要,但并不清楚该特征如何影响预测结果。SHAP 算法的最大优势是能反应每一个样本中特征的影响力,且可表现出影响的正负性。SHAP算法的主要思想为:控制变量法,如果某个特征出现或不出现,直接影响分类结果,那么该特征一定是比较重要的。因此,可以通过计算该特征出现或不出现的各种情况,来计算其对于分类结果的贡献度。在 SHAP 算法中用沙普利(Shapley )值表示不同特征对于预测结果的贡献度。Shapley 值是博弈论中使用的一种方法,它涉及公平地将收益和成本分配给在联盟中工作的行动者,由于每个行动者对联盟的贡献是不同的,Shapley 值保证每个行动者根据贡献的多少获得公平的份额。

2 代码实现

接下来展示如何用 SHAP 来解释基于 BERT 的文本分类任务,直接从 SHAP官网上扒下来代码:

import nlp
import numpy as np
import scipy as sp
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
import shap

# load a BERT sentiment analysis model
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-uncased"
)
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).cuda()

# define a prediction function
def f(x):
    tv = torch.tensor(
        [
            tokenizer.encode(v, padding="max_length", max_length=512, truncation=True) for v in x
        ]
    ).cuda()
    outputs = model(tv)
    outputs = outputs[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores[:, 1])  # use one vs rest logit units
    return val
    
# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer)

# explain the model's predictions on IMDB reviews
imdb_train = nlp.load_dataset("imdb")["train"]
shap_values = explainer(imdb_train[:10], fixed_context=1, batch_size=2)

执行下列代码,用于展示数据集中第3个样本中不同特征对于预测结果的贡献度/值:

#plot the first sentence's explanation
shap.plots.text(shap_values[3])

shap.plots.bar(shap_values.abs.sum(0))

运行得到下列输出:

shap.plots.bar(shap_values.abs.max(0))

3 参考:

[1] text plot — SHAP latest documentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/757064.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

python系列30:各种爬虫技术总结

1. 使用requests获取网页内容 以巴鲁夫产品为例,可以用get请求获取内容: https://www.balluff.com.cn/zh-cn/products/BES02YF 对应的网页为: 使用简单方法进行解析即可 import requests r BES02YF res requests.get("https://www.…

JavaSE主要内容(全套超完整)

一、为什么选择Java(Java的优势) 1、应用面广: 相较于其他语言,Java的应用面可谓是非常广,这得益于他的跨平台性和其性能的稳定性。他在服务器后端,Android应用开发,大数据开发&#xf…

FastAPI-Cookie

fastapi-learning-notes/codes/ch01/main.py at master Relph1119/fastapi-learning-notes GitHub 1、Cookie的作用 Cookie可以充当用户认证的令牌,使得用户在首次登录后无需每次手动输入用户名和密码,即可访问受限资源,直到Cookie过期或…

设计模式——责任链

责任链模式是一种行为设计模式,用于将请求的发送者和接收者解耦。在这种模式中,请求通过一条由多个对象组成的链传递,直到有一个对象能够处理该请求为止。每个对象都可以决定是否处理请求以及是否将请求传递给下一个对象。 责任链模式通常在…

数字时代的软件架构:持续架构的兴起与架构师角色的转变

在数字化浪潮的推动下,软件架构领域正经历着前所未有的变革。Eoin Woods在《数字时代的软件架构》演讲中,深入探讨了这一变革,并提出了“持续架构”这一概念。本文将基于Eoin的观点,结合个人理解,探讨持续架构的重要性…

2000-2021年县域金融机构存贷款数据

2000-2021年县域金融机构存贷款数据 1、时间:2000-2021年 2、指标:统计年度、地区编码ID、县域代码、县域名称、所属地级市、所属省份、年末金融机构贷款余额/亿元、年末金融机构存款余额/亿元、年末城乡居民储蓄存款余额/亿元 3、来源:县…

音频Balance源码总结

音频Balance源码总结 何为音频Balance? 顾名思义,Balance及平衡,平衡也就是涉及多方,音频左右甚至四通道,调节所有通道的音量比,使用户在空间内听到各个通道的音频大小不一,好似置身于真实环境…

姚期智、张亚勤、薛澜、Stuart Russell、Max Tegmark,DeepMind研究员等共话全球AI治理丨大会回顾...

为什么AI安全已迫在眉睫?如何构建全球范围内的合作?民众该如何参与到其中?未来的AI系统将是什么样的? 2024年6月15日,智源大会第二天,多位AI安全领域专家进行圆桌讨论,连接中国北京和美国加利福…

Android隐藏状态栏和修改状态栏颜色_亲测有效

本文记录了隐藏状态栏和修改状态栏颜色以及电量、WiFi标志等内容的模式显示,亲测有效。 1、隐藏屏幕状态栏 public void hideStatusBar(BaseActivity activity) {Window window activity.getWindow();//没有这一行无效window.addFlags(WindowManager.LayoutParam…

基于自组织长短期记忆神经网络的时间序列预测(MATLAB)

LSTM是为了解决RNN 的梯度消失问题而诞生的特殊循环神经网络。该网络开发了一种异于普通神经元的节点结构,引入了3 个控制门的概念。该节点称为LSTM 单元。LSTM 神经网络避免了梯度消失的情况,能够记忆更长久的历史信息,更能有效地拟合长期时…

Spring Cloud LoadBalancer基础入门与应用实践

官网地址:https://docs.spring.io/spring-cloud-commons/reference/spring-cloud-commons/loadbalancer.html 【1】概述 Spring Cloud LoadBalancer是由SpringCloud官方提供的一个开源的、简单易用的客户端负载均衡器,它包含在SpringCloud-commons中用…

[OtterCTF 2018]Recovery

里克必须找回他的文件!用于加密文件的随机密码是什么 恢复他的文件 ,感染的文件 ? vmware-tray.ex 前面导出的3720.dmp 查找一下 搜索主机 strings -e l 3720.dmp | grep “WIN-LO6FAF3DTFE” 主机名 后面跟着一串 代码 aDOBofVYUNVnmp7 是不…

C++并发之环形队列(ring,queue)

目录 1 概述2 实现3 测试4 运行 1 概述 最近研究了C11的并发编程的线程/互斥/锁/条件变量,利用互斥/锁/条件变量实现一个支持多线程并发的环形队列,队列大小通过模板参数传递。 环形队列是一个模板类,有两个模块参数,参数1是元素…

【操作系统期末速成】 EP02 | 学习笔记(基于五道口一只鸭)

文章目录 一、前言🚀🚀🚀二、正文:☀️☀️☀️2.1 考点二:操作系统的功能及接口2.2 考点三:操作系统的发展及分类2.3 考点四:操作系统的运行环境(重要) 一、前言&#x…

C++输出彩色方块

1.使用方法 SetConsoleTextAttribute(GetStdHandle(STD_OUTPUT_HANDLE), 0xab); ———————————————————————————————————————— 0 黑色 1 蓝色 2 绿色 3 湖蓝色 4 红色 5 紫色 6 黄色 7 白色 8 灰色 9 …

【限免】线性调频信号的脉冲压缩及二维分离SAR成像算法【附MATLAB代码】

文章来源:微信公众号:EW Frontier QQ交流群:949444104 程序一 对线性调频信号进行仿真,输出其时频域的相关信息,并模拟回波信号, 对其进行脉冲压缩和加窗处理。 实验记录: 1.线性调频信号时…

24年hvv前夕,微步也要收费了,情报共享会在今年结束么?

一个人走的很快,但一群人才能走的更远。吉祥同学学安全https://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247483727&idx1&sndb05d8c1115a4539716eddd9fde4e5c9&scene21#wechat_redirect这个星球🔗里面已经沉淀了: 《Ja…

自闭症早期风险判别和干预新路径

谷禾健康 自闭症谱系障碍 (ASD) 是一组神经发育疾病,其特征是社交互动和沟通的质量障碍、兴趣受限以及重复和刻板行为。 环境因素在自闭症中发挥重要作用,多项研究以及谷禾队列研究文章表明肠道微生物对于自闭症的发生和发展以及存在明显的菌群和代谢物的…

智慧校园-报修管理系统总体概述

智慧校园报修管理系统是专为优化教育机构内部维修报障流程而设计的信息化解决方案,它通过集成现代信息技术,为校园设施的维护管理带来革新。该系统以用户友好和高效运作为核心,确保了从报修请求提交到问题解决的每一个步骤都顺畅无阻。 师生或…

Apache IoTDB 监控详解 | 分布式系统监控基础

IoTDB 分布式系统监控的基础“须知”! 我这个环境的系统性能一直无法提升,能否帮我找到系统的瓶颈在哪里? 系统优化后,虽然写入性能有所提升,但查询延迟却增加了,下一步我该如何排查和优化呢? 请…