Post-hoc Concept Bottleneck Models (PCBM)

ICLR 2023 spotlight

文章链接:https://arxiv.org/abs/2205.15480

代码链接:https://github.com/mertyg/post-hoc-cbm

一、概述

        Post-hoc CBM(PCBM)也是CBM大家族中的一员,因此它的基本逻辑与CBM一致,就是在输入和输出之间构造一个bottleneck用于预测concepts。和其它很多文章类似,作者同样指出了CBM模型的缺点:

        (i) dense annotation,即需要大量精细的标注;

        (ii)  accuracy-interpretability trade-off,即准确性与可解释性之间的取舍与权衡(尤其是在concepts not enough的情况下);

        (iii) local intervention,即CBM只是针对个例进行干预,而不是提升模型本身的效果。

        因此,本文提出PCBM,可以将任何网络转化为PCBM,且在不牺牲模型精度的同时保证可解释性;此外,当训练集中缺失annotation时,PCBM可以从其它数据集或使用多模态模型产生概念:transfer concepts from other datasets or from natural descriptions of concepts via multimodal models”——在介绍CBM那篇文章的时候提到过——或者,引入一个residual modeling step来recover the original blackbox model's performance。此外,PCBM允许global model edits即全局的intervention,这种方法会比针对specific prediction的local intervention更加有效。


二、方法

        We let f:\mathcal{X}\rightarrow\mathbb{R}^{d} be any pretrained backbone model, where d is the size of the corres-ponding embedding space and \mathcal{X} is the input space.  f 可以是CLIP中的image encoder或者ResNet的倒数第二层(总之是一个编码器)。

        建立PCBM需要以下几个步骤:

(i) Learning the Concept Subspace

        为了学习concept representations,作者使用了CAVs的做法,首先定义了一个概念集合concept library I=\left \{ i_1,i_2,...,i_{N_c} \right \},其中 N_c 代表concepts的总数;concept library可以由domain expert定义或者从数据中自动学习(参考NeurIPS 2019, Towards automatic concept-based explanations.https://arxiv.org/abs/1902.03129)。

        For each concept i, we collect embeddings for the positive examples, denoted by the set P_i, and negative examples N_i.

P_i=\left \{ f(x_{p_1}),...,f(x_{p_{N_p}}) \right \}

N_i=\left \{ f(x_{n_1}),...,f(x_{n_{N_n}}) \right \}

        作者训练了一个SVM对 P_i 与 N_i 分类,并计算对应的CAV(分类边界的法向量),并且与TCAV相同,CAV的学习并不局限于the data used to train the backbone model;将第 i 个concept对应的CAV记为 \boldsymbol{c}_i,let \boldsymbol{C}\in \mathbb{R}^{N_c\times d} denote the matrix of concept vectors. \boldsymbol{C} 的每一行就代表第 i 个concept对应的CAV \boldsymbol{c}_i

        现在,我们有一个backbone model f 作为encoder,一个由一系列CAVs组成的concept matrix \boldsymbol{C}。此时给定输入 x,我们可以通过 f_{\boldsymbol{C}}(x)=\mathrm{proj}_{\boldsymbol{C}}f(x)\in\mathbb{R}^{N_c}将 f(x) 投影到由 \boldsymbol{C} 张成的向量空间,i.e., f_{\boldsymbol{C}}^{(i)}(x)=\frac{\left \langle f(x),\boldsymbol{c}_i \right \rangle}{\left \| \boldsymbol{c_i} \right \|_{2}^{2}}\in\mathbb{R},即 f_{\boldsymbol{C}}^{(i)}(x) 代表当前输入在第 i 个concept vector \boldsymbol{c}_i 方向上的长度(是一个scalar),直观来说就是当前输入 x 中包含概念 \boldsymbol{c}_i 的程度(图中红色方框)👇

(ii) Leveraging multimodal models to learn concepts

        前面提到CBM需要dense annotation,限制了实际应用。作者提出可以使用多模态模型比如CLIP来生成concept vector,具体来说,由于CLIP (Radford et al., 2021)具有一个image encoder和一个text encoder可以将二者编码到shared embedding space中,因此我们可以通过mapping the prompt using the text encoder to obtain the concept vectors;举例来说,如果我们想得到“strpes”这一concept对应的CAV但是又缺少标注好的数据,我们可以通过将“stripes”输入到CLIP的text encoder中,使用其编码后得到的向量作为CAV(其实就不叫CAV了,但是得到的这个向量也是类似CAV的一种用来表示概念的向量;为方便理解,此处索性就统一叫作CAV,但不要混淆),i.e. \boldsymbol{c}_{\textrm{stripes}}^{\textrm{text}}=f_{\textrm{text}}(\textrm{"stripes"});这样,对于每一个concept我们都有对应的语言表述,也都能相应地得到CAV,由此得到我们的multimodal concept bank \boldsymbol{C}^\textrm{text}.

Note:CAVs与Multimodal Models两种方法二选一,而不是将两种方法得到的CAV求并。

        对于classification task,可以使用ConceptNet (Speer et al., 2017)来自动获取与类别相关的concepts,从而构建concept bank。

(iii) Learning the Interpretable Predictor

        Let g:\mathbb{R}^{N_c}\rightarrow \mathcal{Y} be an interpretable predictor. g 可以选择线性模型或者决策树这种具有较强可解释性的模型,将预测得到的评分 f_{\boldsymbol{C}}(x) 映射为最终的类别 \mathcal{Y}。通过优化以下式子来学习模型:

\min\limits_g \mathbb{E}_{(x,y)\sim \mathcal{D}}[\mathcal{L}(g(f_{\boldsymbol{C}}(x)),y)]+\frac{\lambda }{N_cK}\Omega (g)

\Omega (g)=\alpha \left \| \omega \right \|_1(1-\alpha )\left \| \omega \right \|_{2}^{2}

        前面一项对应分类损失(如交叉熵),后面一项为正则项,用来限制predictor g 的复杂度,并由类别和概念的数量进行归一化。在这项工作中作者使用的是sparse linear models。

(iv) Recovering the original model performance with residual modeling

        即使我们拥有了一个相对丰富的概念子空间,概念很可能仍然不足以解决我们感兴趣的下游任务。对于这种情况,即PCBM与原始模型性能不匹配时,作者引入了从original embedding连接到最终决策层的残差部分,以保持原有模型的准确度,对应的模型为PCBM-h。此时,作者使用sequential的训练方式,首先训练 interpretable predictor g ,然后固定concept bottleneck and the interpretable predictor并优化残差部分:

\min\limits_r \mathbb{E}_{(x,y)\sim \mathcal{D}}[\mathcal{L}(g(f_{\boldsymbol{C}}(x))+r(f(x)),y)]

        其中 r 是residual predictor,其输入是原始的不具有解释性的embedding,而最后的输出结果是综合了interpretable predictor的输出 g(f_{\boldsymbol{C}}(x))以及residual predictor的输出 r(f(x))。可以将r(f(x)) 视为原来interpretable predictor的一种补充;g的输入是interpretable concept embeddings,r 的输入是uninterpretable的original embeddings from backbone encoder. 模型的决策由 g 尽量解释,解释不了的由 r 来恢复原始精度。很显然,PCBM-h的精度一定是高于PCBM的。

Note:如果想观察interpretable predictor g 的表现,那么就把residual predictor r 网络中的参数全部置零从而drop掉这一支路,如果我们想得到一个黑盒模型,就把 g 网络中的参数全部置零。


三、实验及结果

(i) PCBMs achieve comparable performance to the original model

        PCBMs获得了与黑盒模型comparable的性能,尤其是PCBM-h。

(ii) PCBMs achieve comparable performance to the original model

        当提供的concepts not available or insufficient的时候,可以使用借助CLIP的text ecncoder产生的concept bank,发现CLIP自动生成的concept要比人为提供的概念标注更好。

(iii) Explaining Post-hoc CBMs

        展示了针对于一个类别线性层中权重最大的三个concepts,在皮肤癌的例子中,模型考虑的concept与人类判断时考虑的因素一致。

(iv) Model editing

        与基本的CBM对单个样本做干预(local intervention)不同,PCBM的一个优势就是允许global intervention从而直接提升模型整体的表现。当我们知道某些概念是错误的时候,可以通过剪枝(Prune)等操作优化模型。举个例子,如果训练集和测试集存在域偏差,比如,训练集中有很多“狗”的图片,但是在测试集中没有“狗”的图片,那么在训练阶段学习到的所有关于狗的概念都将无效,或者说对于测试集是“错误的概念”;此时我们可以采用以下三种strategies对模型进行修改:

        (1) Prune: 在决策层将错误概念对应的权重置0,i.e., for a concept indexed by i, we let \tilde{\boldsymbol{\omega }}_i=0

        (2) PruneNormalize:在prune后rescale the concept weights,归一化可以缓解剪枝后较大权重造成的权值不平衡问题;

        (3) Fine-tune (Oracle):在测试集上对整个模型进行微调,作为oracle。

        可以发现PCBM进行PruneNormalize之后的增益较高,最接近oracle;而PCBM-h的增益很低。一个原因是PCBM可以通过Prune直接剪掉干扰预测的错误概念,但是由于PCBM-h的残差连接中仍包含来自错误概念的信息无法被去除,因此预测精度的提升不明显。

(v) User study

        作者还进行了user-study,即测试集与训练集存在偏差时,让user自行选择一定数量的concepts进行prune,观察模型性能是否有提高,以验证模型能够良好的与人类进行交互;作者使用了三个实验设置作为对比:

        (1) Random Pruning:随机对weights置零;

        (2) Greedy pruning(Oracle):即prune掉与人类同样数量的concepts使得模型得到最佳增益;

        (3) Fine-tune (Oracle):在测试集上微调。

        Random prune发生了性能降低,而user prune可以明显改善模型性能,大概相当于80%的greedy prune增益与50%的fine-tune增益。

        另一个现象是即使有残差连接但是仍然可以通过剪枝提高PCBM-h的性能,具体原因不知道。


        最后是简单的discussion:

        (1) 人类构建的concept bottleneck是否可以解决更大规模的任务是一个悬而未决的问题(例如ImageNet级别),因为会有information bottleneck的存在,精度concept定义insufficient,也是导致accuracy-interpretability之间有trade-off的原因所在。

        (2) 以无监督的方式为模型寻找概念子空间是一个活跃的研究领域,它将有助于构建更加有用的、丰富的概念瓶颈。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/284185.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从登录测试谈测试用例

谈谈登录测试: 可能你会说,“用户登录”这个测试对象也有点太简单了吧,我只要找一个用户,让他在界面上输入用户名和密码,然后点击“确 认”按钮,验证一下是否登录成功就可以了。的确,这构成了一…

C/C++ 函数的默认参数

下面介绍一项新内容 - 默认参数。 默认参数指的是当函数调用中省略了实参时自动使用的一个值。 例如,如果将 void wow (int n)设置成n 有默认值为1,则函数调用 wow()相当于 wow(1)这极大地提高了使用函数的灵活性。 假设有一个名为left()的函数&#xff…

构建安全的SSH服务体系

某公司的电子商务站点由专门的网站管理员进行配置和维护,并需要随时从Internet进行远程管理,考虑到易用性和灵活性,在Web服务器上启用OpenSSH服务,同时基于安全性考虑,需要对 SSH登录进行严格的控制,如图10…

记一次JSF异步调用引起的接口可用率降低 | 京东云技术团队

前言 本文记录了由于JSF异步调用超时引起的接口可用率降低问题的排查过程,主要介绍了排查思路和JSF异步调用的流程,希望可以帮助大家了解JSF的异步调用原理以及提供一些问题排查思路。本文分析的JSF源码是基于JSF 1,7.5-HOTFIX-T6版本。 起因 问题背景…

【MATLAB】【数字信号处理】基本信号的仿真与实现

目的 1、用MATLAB软件实现冲激序列 2、用MATLAB软件实现阶跃序列 3、用MATLAB软件实现指数序列 4、用MATLAB软件实现正弦序列 内容与测试结果 1、用MATLAB软件实现冲激序列 程序如下: % 1 冲激序列 clc; clear all; n0 -10; nf 50; ns 1; A 1;%起点为-1&…

SpringBoot灵活集成多数据源(定制版)

如来说世界,非世界,是名世界 如来说目录,非目录,是名目录 前言前期准备代码实现演示扩展 前言 本篇博客基于SpringBoot整合MyBatis-plus,如果有不懂这个的, 可以查看我的这篇博客:快速CRUD的秘诀…

Linux 权限掌控术:深入探索和用户管理

文章目录 前言1.外壳程序是什么?外壳程为什么存在?工作原理外壳程序怎么个事? 2. Linux权限的概念2.1 什么是权限2.2权限的本质2.3 Linux中的用户 3. 普通用户变成rootlinux中有三种人 4.Linux中文件的权限4.1文件的属性权限4.2 掌握修改权限…

数字集成系统设计——逻辑综合

目录 一、概述 1.1 综合的分类 1.2 逻辑综合的基本架构 1.3 逻辑综合的内部流程 1.3.1 RTL代码转译(Translation) 1.3.2 逻辑级优化(Optimization) 1.3.3 工艺映射(Mapping) 二、优化策略 2.1 资源…

Linux之进程管理

什么是进程 在linux中每个执行的程序都称为一个进程,每个进程都分配一个ID号(pid进程号)。每个进程都可能以两种方式存在,即前台和后天。前台进程就是用户目前的屏幕上可以进行操作的。后台进程则是实际在操作,但屏幕…

AD教程 (二十一)模块化布局规划

AD教程 (二十一)模块化布局规划 原理图是按照我们的功能模块去进行排布划分的 利用交叉选择模式分屏快速进行模块化布局 分屏,选中任意文档,右击,点击垂直分割 交叉选择模式,点击工具,交叉选…

【模拟电路】软件Circuit JS

一、模拟电路软件Circuit JS 二、Circuit JS软件配置 三、Circuit JS 软件 常见的快捷键 四、Circuit JS软件基础使用 五、Circuit JS软件使用讲解 欧姆定律电阻的串联和并联电容器的充放电过程电感器和实现理想超导的概念电容阻止电压的突变,电感阻止电流的突变LR…

基于SpringBoot的校园二手闲置交易平台

基于SpringBoot的校园二手闲置交易平台的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringBootMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 登录界面 管理员界面 摘要 本文基于Spring Boot框架设计并实现了一款…

buuctf-Misc 题目解答分解103-105

103.[GKCTF 2021]签到 追踪流发现类似flag 字符 f14g 下面有大量的是16进制字符 64306c455357644251306c6e51554e4a5a3046355355737764306c7154586c4a616b31355357704e65556c7154586c4a616b31355357704e65556c7154586c4a616b31355357704e65556c7154586c4a616b31355357704e655…

git rebase应用场景三

文章目录 git rebase应用场景三 git rebase应用场景三 在我们的开发分支中 假设我们修改一个文件 提交一个版本 再回到master分支 同时也去修改1.txt文件,提交一个版本 这样相当于master分支提交了一次,dev也提交了一次 然后回到dev分支 此时会报错…

【网络安全】upload靶场pass11-17思路

目录 Pass-11 Pass-12 Pass-13 Pass-14 Pass-15 Pass-16 Pass-17 🌈嗨!我是Filotimo__🌈。很高兴与大家相识,希望我的博客能对你有所帮助。 💡本文由Filotimo__✍️原创,首发于CSDN📚。 &#x…

gRPC之内置Trace

1、内置Trace grpc内置了客户端和服务端的请求追踪,基于golang.org/x/net/trace包实现,默认是开启状态,可以查看事 件和请求日志,对于基本的请求状态查看调试也是很有帮助的,客户端与服务端基本一致,这里…

Delphi6函数大全4-SysUtils.pas

Delphi6函数大全4-SysUtils.pas首部 function FormatFloat(const Format: string; Value: Extended): string; $[SysUtils.pas功能 返回浮点数类型以指定格式字符串Format转换成字符串说明 FormatFloat(,.00, 1234567890) 1,234,567,890.00参考 function …

element表格排序功能

官方展示 个人项目 可以分别对每一项数据进行筛选 注&#xff1a;筛选的数据不能是字符串类型必须是数字类型&#xff0c;否则筛选会乱排序 html <el-table :data"tableData" border height"600" style"width: 100%"><el-table-co…

实验六——cache模拟器实验

前言 本次实验的主要目的是熟悉cache的原理。加深对cache的映像规则、替换方法、cache命中与缺失的理解。通过实验对比分析映像规则对cache性能的影响。 实验内容一&#xff1a;熟悉模拟程序 阅读给出的cache模拟程序&#xff08;cachesimulator.cpp&#xff09;&#xff0c;…

Linux学习之系统编程1(关于读写系统函数)

写在前面&#xff1a; 我的Linux的学习之路非常坎坷。第一次学习Linux是在大一下的开学没多久&#xff0c;结果因为不会安装VMware就无疾而终了&#xff0c;可以说是没开始就失败了。第二次学习Linux是在大一下快放暑假&#xff08;那个时候刚刚过完考试周&#xff09;&#xf…