ICLR 2023 spotlight
文章链接:https://arxiv.org/abs/2205.15480
代码链接:https://github.com/mertyg/post-hoc-cbm
一、概述
Post-hoc CBM(PCBM)也是CBM大家族中的一员,因此它的基本逻辑与CBM一致,就是在输入和输出之间构造一个bottleneck用于预测concepts。和其它很多文章类似,作者同样指出了CBM模型的缺点:
(i) dense annotation,即需要大量精细的标注;
(ii) accuracy-interpretability trade-off,即准确性与可解释性之间的取舍与权衡(尤其是在concepts not enough的情况下);
(iii) local intervention,即CBM只是针对个例进行干预,而不是提升模型本身的效果。
因此,本文提出PCBM,可以将任何网络转化为PCBM,且在不牺牲模型精度的同时保证可解释性;此外,当训练集中缺失annotation时,PCBM可以从其它数据集或使用多模态模型产生概念:“transfer concepts from other datasets or from natural descriptions of concepts via multimodal models”——在介绍CBM那篇文章的时候提到过——或者,引入一个residual modeling step来recover the original blackbox model's performance。此外,PCBM允许global model edits即全局的intervention,这种方法会比针对specific prediction的local intervention更加有效。
二、方法
We let be any pretrained backbone model, where is the size of the corres-ponding embedding space and is the input space. 可以是CLIP中的image encoder或者ResNet的倒数第二层(总之是一个编码器)。
建立PCBM需要以下几个步骤:
(i) Learning the Concept Subspace
为了学习concept representations,作者使用了CAVs的做法,首先定义了一个概念集合concept library ,其中 代表concepts的总数;concept library可以由domain expert定义或者从数据中自动学习(参考NeurIPS 2019, Towards automatic concept-based explanations.https://arxiv.org/abs/1902.03129)。
For each concept , we collect embeddings for the positive examples, denoted by the set , and negative examples .
作者训练了一个SVM对 与 分类,并计算对应的CAV(分类边界的法向量),并且与TCAV相同,CAV的学习并不局限于the data used to train the backbone model;将第 个concept对应的CAV记为 ,let denote the matrix of concept vectors. 的每一行就代表第 个concept对应的CAV 。
现在,我们有一个backbone model 作为encoder,一个由一系列CAVs组成的concept matrix 。此时给定输入 ,我们可以通过 将 投影到由 张成的向量空间,i.e., ,即 代表当前输入在第 个concept vector 方向上的长度(是一个scalar),直观来说就是当前输入 中包含概念 的程度(图中红色方框)👇
(ii) Leveraging multimodal models to learn concepts
前面提到CBM需要dense annotation,限制了实际应用。作者提出可以使用多模态模型比如CLIP来生成concept vector,具体来说,由于CLIP (Radford et al., 2021)具有一个image encoder和一个text encoder可以将二者编码到shared embedding space中,因此我们可以通过mapping the prompt using the text encoder to obtain the concept vectors;举例来说,如果我们想得到“strpes”这一concept对应的CAV但是又缺少标注好的数据,我们可以通过将“stripes”输入到CLIP的text encoder中,使用其编码后得到的向量作为CAV(其实就不叫CAV了,但是得到的这个向量也是类似CAV的一种用来表示概念的向量;为方便理解,此处索性就统一叫作CAV,但不要混淆),i.e. ;这样,对于每一个concept我们都有对应的语言表述,也都能相应地得到CAV,由此得到我们的multimodal concept bank .
Note:CAVs与Multimodal Models两种方法二选一,而不是将两种方法得到的CAV求并。
对于classification task,可以使用ConceptNet (Speer et al., 2017)来自动获取与类别相关的concepts,从而构建concept bank。
(iii) Learning the Interpretable Predictor
Let be an interpretable predictor. 可以选择线性模型或者决策树这种具有较强可解释性的模型,将预测得到的评分 映射为最终的类别 。通过优化以下式子来学习模型:
前面一项对应分类损失(如交叉熵),后面一项为正则项,用来限制predictor 的复杂度,并由类别和概念的数量进行归一化。在这项工作中作者使用的是sparse linear models。
(iv) Recovering the original model performance with residual modeling
即使我们拥有了一个相对丰富的概念子空间,概念很可能仍然不足以解决我们感兴趣的下游任务。对于这种情况,即PCBM与原始模型性能不匹配时,作者引入了从original embedding连接到最终决策层的残差部分,以保持原有模型的准确度,对应的模型为PCBM-h。此时,作者使用sequential的训练方式,首先训练 interpretable predictor ,然后固定concept bottleneck and the interpretable predictor并优化残差部分:
其中 是residual predictor,其输入是原始的不具有解释性的embedding,而最后的输出结果是综合了interpretable predictor的输出 以及residual predictor的输出 。可以将 视为原来interpretable predictor的一种补充;的输入是interpretable concept embeddings, 的输入是uninterpretable的original embeddings from backbone encoder. 模型的决策由 尽量解释,解释不了的由 来恢复原始精度。很显然,PCBM-h的精度一定是高于PCBM的。
Note:如果想观察interpretable predictor 的表现,那么就把residual predictor 网络中的参数全部置零从而drop掉这一支路,如果我们想得到一个黑盒模型,就把 网络中的参数全部置零。
三、实验及结果
(i) PCBMs achieve comparable performance to the original model
PCBMs获得了与黑盒模型comparable的性能,尤其是PCBM-h。
(ii) PCBMs achieve comparable performance to the original model
当提供的concepts not available or insufficient的时候,可以使用借助CLIP的text ecncoder产生的concept bank,发现CLIP自动生成的concept要比人为提供的概念标注更好。
(iii) Explaining Post-hoc CBMs
展示了针对于一个类别线性层中权重最大的三个concepts,在皮肤癌的例子中,模型考虑的concept与人类判断时考虑的因素一致。
(iv) Model editing
与基本的CBM对单个样本做干预(local intervention)不同,PCBM的一个优势就是允许global intervention从而直接提升模型整体的表现。当我们知道某些概念是错误的时候,可以通过剪枝(Prune)等操作优化模型。举个例子,如果训练集和测试集存在域偏差,比如,训练集中有很多“狗”的图片,但是在测试集中没有“狗”的图片,那么在训练阶段学习到的所有关于狗的概念都将无效,或者说对于测试集是“错误的概念”;此时我们可以采用以下三种strategies对模型进行修改:
(1) Prune: 在决策层将错误概念对应的权重置0,i.e., for a concept indexed by , we let ;
(2) PruneNormalize:在prune后rescale the concept weights,归一化可以缓解剪枝后较大权重造成的权值不平衡问题;
(3) Fine-tune (Oracle):在测试集上对整个模型进行微调,作为oracle。
可以发现PCBM进行PruneNormalize之后的增益较高,最接近oracle;而PCBM-h的增益很低。一个原因是PCBM可以通过Prune直接剪掉干扰预测的错误概念,但是由于PCBM-h的残差连接中仍包含来自错误概念的信息无法被去除,因此预测精度的提升不明显。
(v) User study
作者还进行了user-study,即测试集与训练集存在偏差时,让user自行选择一定数量的concepts进行prune,观察模型性能是否有提高,以验证模型能够良好的与人类进行交互;作者使用了三个实验设置作为对比:
(1) Random Pruning:随机对weights置零;
(2) Greedy pruning(Oracle):即prune掉与人类同样数量的concepts使得模型得到最佳增益;
(3) Fine-tune (Oracle):在测试集上微调。
Random prune发生了性能降低,而user prune可以明显改善模型性能,大概相当于80%的greedy prune增益与50%的fine-tune增益。
另一个现象是即使有残差连接但是仍然可以通过剪枝提高PCBM-h的性能,具体原因不知道。
最后是简单的discussion:
(1) 人类构建的concept bottleneck是否可以解决更大规模的任务是一个悬而未决的问题(例如ImageNet级别),因为会有information bottleneck的存在,精度concept定义insufficient,也是导致accuracy-interpretability之间有trade-off的原因所在。
(2) 以无监督的方式为模型寻找概念子空间是一个活跃的研究领域,它将有助于构建更加有用的、丰富的概念瓶颈。