主讲老师:曾文轩
学习链接:第12讲:基于隐语的Vision Transformer框架
论文:【ICCV2023】MPCViT: Searching for Accurate and Efficient MPC-Friendly Vision Transformer with Heterogeneous Attention
隐语课程第12课,是一次论文及使用案例分享课,基于隐语实现Vision Transformer框架MPCViT。首先介绍了ViT在MPC环境下进行隐私推理的总体框架、研究动机以及算法流程。MPCViT借助SecretFlow-SPU搭建ViT框架,并测试不同模型架构的推理效率。介绍了SecretFlow-SPU的特点和功能,并展开讲解了基于SecretFlow和Jax的ViT框架搭建流程。最后,呈现了该工作的主要实验结果,效果均优于基线模型。
1. MPCViT:安全且高效的MPC友好型 Vision Transformer架构
-
客户端与服务器:
- 客户端: 提供输入数据。
- 服务器: 执行ViT模型计算。
-
数据处理流程:
- 输入嵌入(Input Embedding): 将输入图像分割成小块(Patch),并对其进行线性投影。
- 位置编码(Position Encoding): 为每个嵌入块添加位置信息。
- 层归一化(LayerNorm): 对嵌入数据进行归一化处理。
-
多头注意力机制(Multi-Head Attention):
- 线性变换(Linear): 将输入数据进行线性变换。
- 矩阵乘法(MatMul): 执行矩阵乘法操作。
- Softmax: 应用Softmax函数来计算注意力权重。
- MPC瓶颈(MPC Bottleneck): softmax是处理注意力计算中的瓶颈问题。
-
倒置瓶颈(Inverted Bottleneck):
- 线性层(Linear): 执行线性变换。
- GELU激活函数: 应用GELU激活函数来提高模型的非线性能力也是MPC瓶颈之一。
- 高维处理(High Dimension): 处理高维度数据。
-
多层感知机(MLP):
- 层归一化(LayerNorm): 对数据进行归一化处理。
- 线性层和GELU激活函数: 包含两个线性层和一个GELU激活函数
通过分析, ViT模型的主要通信瓶颈在于多头注意力机制中的Softmax和MLP中的GELU激活函数。其中Softmax函数中Max以Reciprocal操作又是主要耗时的算子。另外对比分析不同注意力机制变体的准确率和延迟。
已有一些工作发现不是所有的注意力都同等重要,是存在差别的。这个也是引出MPCViT的一个主要动机之一。
MPCViT的提出,主要是为了解决这两个主要问题: (1) 权衡模型准确率和推理延迟。(2)融合高准确率注意力机制和低延迟注意力机制。主要从四个角度出发进行算法优化:(1)设计合适的搜索空间;(2)MPC感知神经架构搜索;(3)基于延迟限制的架构参数二值化;(4)重训练异构注意力ViT。
搜索算法涉及三种粒度:(1)粗粒度(Transformer层级粒度);(2)中粒度(注意力头级粒度);(3)细粒度(注意力行级粒度 token级)。
对每一个注意力分配了alpha参数,搜索过程中自动确定应该保留哪个注意力,可微分的搜索算法效率可以得到保证。搜索完成之后,可以对alpha进行排序,值较大的注意力设置为高延迟类型,值较小的注意力设置为低延迟类型。
直接训练搜索后的异构注意力机制ViT会导致显著的准确率下降,采用多粒度的自蒸馏技术,将原始Softmax注意力机制ViT作为教师模型, 无需引入任何额外训练和推理开销。基于特征的token蒸馏,更加细粒度蒸馏,本文取最后一层特征蒸馏,效率更好。
2. 使用SecretFlow搭建ViT框架
介绍完MPCViT的原理后,接下来就是使用SecretFlow搭建ViT框架。关于SPU的介绍,可以参考我之前的笔记《隐语课程学习笔记8-理解密态引擎SPU框架》。
隐私推理协议的相关参数设置以及网络环境模拟。
基于Jax的ViT模型搭建,之前我们已经实践过NN模型以及GPT2模型的SPU实践,所以对这一块内容理解起来会比较快。使用Jax实现明文的模型结构,主要分为patch embedding搭建、注意力机制搭建、MLP模块搭建、Transformer模块搭建。
(1)Patch embedding搭建
(2)注意力机制搭建
(3)MLP模块搭建
(4)Transformer模块搭建:基于之前的各个子模块,来构建完整的Transformer模块。
(5)ViT模型的隐私推理:(1)初始化SPU环境;(2)设置输入变量维度(主要是编码维度、token数量、头数量等);(3)创建对象(注意力层等);(4)指定SPU设备以及执行函数,加密输入,在密文环境执行计算。
ViT模型推理操作流程,包括(1)配置Python环境及安装SPU,并激活Python环境;(2)配置并模拟通信网络环境(以WAN设置为例);(3)模拟MPC环境及协议;(4)执行隐私推理。
3. 该工作的主要实验结果,效果均优于基线模型
- 三种实验基线模型包括:Linformer (FaceBook 2020), THE-X (ACL 2021), MPCFormer (ICLR 2023);
- 三种实验数据集包括:CIFAR-10、CIFAR-100、Tiny-ImageNet;
- 实验结论:在Tiny-ImageNet数据集上,相比基线ViT、MPCFormer、THE-X,MPCViT具有更低的延迟和更好的准确率,延迟分别降低6.2×、2.9×和1.9×,提高了1.9%、1.3%和3.6%的准确率;
- 与仅使用交叉熵损失函数相比,基于logits和基于特征的知识蒸馏都使得ViT性能显著提升;
- 两种知识蒸馏损失函数的结合进一步提高了准确率,在更大的数据集上表现更明显;
- 在不同的λ、不同的数据集、不同的注意力头数量的设置下,搜索结构都可以表现出一致性;
- 架构参数可视化与层级可视化一致,更倾向于保留中间靠前的注意力。
总体来说,将隐语应用于实际的算法工作中,可行度较高,对开发者也更加友好,后续也会在实际学习和工作中多多使用。