Mixture-of-Depths: Dynamically allocating compute in transformer-based language models
公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
注:Transformer 由于其注意力机制(矩阵乘法),面临着二次计算量和高内存的挑战,导致推理速度缓慢、吞吐量低、难以处理长上下文等问题。Mamba 被提出来解决这些问题。Mamba 基于状态空间模型(SSM),其当前状态只与前一状态和当前输入有关。基于选择性状态空间的 Mamba 使用选择来压缩输入中有用的信息,从而缓解了 Transformer 面临的问题,能够以接近 Transformer 的性能处理更长的上下文。但由于其压缩信息的特性,其语境内学习(in-context learning)的能力不如Transformer。Jamba 通过交错使用 Transformer 以及 Mamba 层的块,以及额外的混合专家(MoE)来充分发挥 Transformer 和 Mamba 的优势。
(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模
(2024,Attention-Mamba,MoE 替换 MLP)Jamba:混合 Transformer-Mamba 语言模型
Mamba Explained (thegradient.pub)
不同于 Mamba,本文通过改进 Transfomer 来减少计算量,从而加快训练速度。
- (Sec. 1)该方法的基础是:在语言建模中,并非所有的 token 和序列都需要相同的时间或精力来准确预测。因此,transformer 应该通过减少不必要的计算来使用更小的总计算预算。
- (Sec. 3)该方法通过路由器为不同的 token 计算权重,从而决定哪些 token 要参与块的计算(自注意力和 MLP),哪些 token 要绕过块。
- (Sec. 4)初步分析表明,更频繁与块进行交互的 token 与具有更高熵的输出预测相关,这可能对应于更难以做出的预测。
- (Sec. 5)利用 MoD,通过将查询、键和值的路由分离,可将 transformer 扩展到 “长期记忆” 的领域。(类似于 Mamba)
目录
0. 摘要
1. 简介
2. 背景
3. 实现深度混合 Transformer
3.1. 定义计算预算
3.2. 围绕 transformer 块进行路由
3.3. 路由方案
3.4. 路由实现
3.5. 抽样
3.6. 训练方法
4. 结果
4.1. 训练,isoFLOP 比较
4.2. 自回归评估
4.3. 深度与专家的混合(MoDE)
5. 讨论
0. 摘要
基于 Transformer 的语言模型在输入序列上均匀分配 FLOPs(Floating point operations per second)。在这项工作中,我们展示了 transformers 可以学习动态地分配 FLOPs(或计算)给序列中的特定位置,优化模型深度中不同层次的序列上的分配。我们的方法通过限制每个给定层中可以参与自注意力和 MLP 计算的 token(𝑘)的数量来限制总计算预算。要处理的 token 由网络使用一个 top-k 路由机制(routing mechanism)确定。由于 𝑘 是预先定义的,这个简单的过程使用已知张量大小的静态计算图(graph),不像其他条件计算技术。然而,由于 𝑘 个 token 的标识(identities)是流动的,这种方法可以在时间和模型深度维度上非均匀地扩展 FLOPs。因此,总的计算支出完全是可预测的,但在 token 级别上是动态的和上下文敏感的。这种方式训练的模型不仅学会了动态分配计算,而且做得非常高效。这些模型与等效 FLOPs 和训练时间相匹配的基线性能相匹配,但每次前向传递需要的 FLOPs 只需一小部分,并且在后训练采样过程中,每步操作速度可提高 50% 以上。
1. 简介
不是所有的问题都需要相同的时间或精力来解决。类似地,在语言建模中,并非所有的 token 和序列都需要相同的时间或精力来准确预测。然而,transformer 模型在前向传播中为每个 token 消耗相同数量的计算。理想情况下,transformer 应该通过减少不必要的计算来使用更小的总计算预算。
条件计算(Conditional computation)是一种试图通过仅在需要时才计算来减少总计算的技术(Bengio等人,2016; Bengio,2013; Bengio等人,2013)。各种算法提供了何时以及如何使用多少计算的解决方案(Ainslie等人,2023; Bapna等人,2020; Fedus等人,2022)。然而,对这个具有挑战性的问题的一般形式化可能不适用于现有的硬件约束,因为它们倾向于引入动态计算图(Dehghani等人,2018; Graves,2016)。最有前景的条件计算方法可能是那些与我们当前的硬件堆栈相协调的方法,该方法优先考虑静态计算图和已知的张量大小,这些大小被选择为最大化硬件利用率。
在这里,我们考虑使用静态计算预算进行语言建模的问题,该预算可以比普通(vanilla) transformer 使用的预算少。网络必须学会如何动态分配可用的计算,通过在每个层中对每个 token 做出决策,决定从可用预算中的哪里消耗计算。在我们的实现中,总计算由用户在训练之前定义并保持不变,而不是网络的即时决策的函数。因此,硬件效率提升,例如减少内存占用量或减少每个前向传播的浮点运算量,可以被预期并提前利用。正如我们将展示的那样,这些收益可以在不牺牲整体性能的情况下获得。
我们利用了类似于混合专家(Mixture of Experts,MoE)transformer 的方法,其中在网络深度上进行动态的 token 级别路由决策。与 MoE 不同,我们选择要么将计算应用于 token (就像标准 transformer 的情况一样),要么通过残差连接传递它(保持不变并节省计算)。与 MoE 相比,我们将此路由应用于前向 MLP 和多头注意力。因此,这也影响到我们处理的键和查询,路由不仅决定更新哪些 token,还决定哪些 token 可用于关注。我们将这种策略称为深度混合(Mixture-of-Depths,MoD),利用 transformer 的深度来强调各个 token 如何通过不同数目的层或块 (参见图1)。
MoD 技术还允许在性能和速度之间进行权衡。一方面,可以训练一个 MoD transformer,相比普通 transformer,它在等效的训练 FLOPs(isoFLOP)上可以提高 1.5% 的最终的对数概率训练目标,而且训练所需的时间(wall-clock)相同。另一方面,可以训练一个 MoD transformer,它实现了与 isoFLOP 最优普通 transformer 相当的训练损失,并且每次前向传播使用的 FLOPs 只是一小部分(超过50%),因此更快。综合起来,这些结果意味着 MoD transformer 学会了智能地路由(即跳过不必要的计算),因为它们可以在每个前向传播中实现相等或更好的序列对数概率,尽管每次前向传播的 FLOPs 足迹较小。
2. 背景
Transformer 架构已经成为人工智能革命的动力,带来了前所未有的能力,但以昂贵的训练运行和服务程序为代价。这引起了人们对使 transformer 架构更有效的巨大兴趣。有望的方法之一是条件计算,通过学习机制确定何时以及如何消耗计算。这个术语是由 Bengio(2013)引入的,并在接下来的几年里进一步探讨了这个概念。
最近的大量工作已经为 transformer 开发了条件计算方法。其中一些工作专注于 “早期退出(early exiting)”,即学习何时结束对给定 token 的计算,允许 token 在做出退出决定后跳过任何剩余的 transformer 层(Elbayad等人,2019; Liu等人,2021; Schuster等人,2022)。在 MoD 中,与早期退出方法不同, token 可以跳过中间层,然后通过自注意力与通过所有中间层的 token 进行更新。我们推测这可能是一个有用的属性。
其他工作开发了一种迭代 transformer 层的方法,该方法具有适应性数量的共享权重的步骤(Dehghani等人,2018; Simoulin和Crabbé,2021)。Bolya 等人(2023)开发了一种方法,在训练后运行推理时选择要合并的 token 的方法,这不需要学习。Lei 等人(2023)在微调设置中利用条件计算,通过在适配器方法(He等人,2021)上构建来学习跳过一小部分冻结的预训练权重的块,而不是运行整个网络。
CoLT5(Ainslie等人,2023)使用条件路由来选择给定 token 是否通过每个前馈层的重路径或轻路径。此外,他们使用相同的路由机制来选择一个 token 是否将注意力集中在所有其他 token 或少数 token 上,就像 Guo 等人(2022)那样。与 MoD 类似,CoLT5 使用软 top k 来做出路由决策。然而,CoLT5 专注于编码器-解码器设置,因此需要解决 top k 操作的非因果性质给出有效的顺序解码的问题。相比之下,我们目前的 MoD 工作专注于仅解码器设置,因此我们提出了一个预测路由器,以实现 transformer 中条件计算的有效推理。
条件计算的一个成功的公式是由 Shazeer 等人(2017)介绍的 “专家混合” 层(MoE)。最初在 LSTM 的上下文中开发,后续的工作显示了 MoE 与 transformer 的引人注目的实证结果(Fedus等人,2022; Lepikhin等人,2020; Zoph等人,2022)。与其他尝试保留或消耗额外计算的条件计算方法不同,MoE transformer 使用条件逻辑将 token 路由到许多专家 MLP 之一,同时保持总计算消耗恒定。我们的深度混合方法可以被认为是使用 MoE transformer 的路由逻辑,但是与具有多个专家的 MoE 不同,MoD 部署了一个可以动态跳过的单个专家。
3. 实现 MoD Transformer
我们的高层策略如下:
- 通过限制序列中可以参与块计算(即自注意力和后续 MLP)的 token 数量,设置小于等效普通 transformer 的静态计算预算。 例如,虽然普通 Transformer 可能允许序列中的所有 token 参与自注意力,但我们可能会将数量限制为序列中 token 的 50%。 参见第 3.1 节。
- 使用每块(per-block)路由器为每个 token 分配标量权重,这表示路由器对该 token 参与块计算或绕过它的偏好。 参见第 3.2 节。
- 确定 top-𝑘 标量权重(每个序列、每个块)以选择将要参与块计算的那些 token。 由于 𝑘 个 token 将参与块的计算,因此计算图和张量大小在整个训练过程中保持静态; 这只是由路由器确定的动态且上下文相关的 token 参与。 参见第 3.3 节。
- 然后我们在 3.5 节中讨论训练后采样时的一些复杂情况。
3.1. 定义计算预算
为了强制执行每个前向传递的总计算预算,我们利用容量的概念,定义构成给定计算输入的 token 总数(例如,参与自注意力的 token 、MoE transformer 中的给定专家等)。 例如,每个普通 transformer 块中的自注意力和 MLP 的容量为 T——序列和 batch 中的 token 总数。 另一方面,MoE transformer 使用的每个专家 MLP 的容量小于 T,以便更均匀地为每个专家分配总计算量。 但是,由于它们每个块使用多个专家,因此它们的总容量大约等于普通 transformer 的容量。
一般来说,决定使用条件计算的 transformer 的总 FLOPs 是由 token 容量决定的,而不是任何路由决策的结果。 这是因为静态图考虑了最坏情况的决策; 例如,计算的输入将被填充到其容量,即使相对较少的 token 最终路由到它。如果超出容量,token 将从计算中删除。 通过降低计算容量,我们可以实现与普通 transformer 相比,每次前向传递使用更小的计算预算的目标。 然而,随意使用较小的计算预算将导致性能下降。 我们假设某些 token 可能不需要像其他 token 那样多的处理,并且这些 token 可以通过学习来识别。 因此,如果网络学会选择正确的 token 来填充其容量,那么它可能会保持其性能。 下面我们描述可用于此目的的路由方案。
3.2. 围绕 transformer 块进行路由
我们考虑将 token 路由到两个计算路径之一的设置:(1) 自注意力和 MLP 块,以及 (2) 残差连接。 后者的计算成本较低,并且产生的块输出完全由其输入值决定。 前一种路径的计算成本很高。
如果我们将路径 (1) 的容量设置为小于 𝑇(序列和 batch 中的 token 总数),则每次前向传递的 FLOPs 总数将少于普通 transformer 中的总数。 例如,如果我们将块的容量设置为 𝑇/2(即,普通 transformer 中 token 数量的一半),那么自注意力期间的 QK 矩阵乘法将变为普通 transformer 中 FLOPs 的 25% (( 𝑇/2)^2 与 𝑇^2)。类似的计算可以确定 MLP 的 FLOPs 节省。
直觉上,每次前向传递的总 FLOPs 随着我们缩小块容量的程度而减少(并且完成前向传递所需的时间也减少)。然而,下游性能也会受到我们缩小块容量的程度以及我们实现的路由算法的影响。
在一个极端情况下,如果我们将每个块的容量保持在 𝑇,并将每个 token 都路由到(而不是绕过)每个块,那么我们就恢复了一个普通的 Transformer。在另一个极端情况下,如果我们将每个块的容量设置为 0,并将所有 token 都绕过每个块路由,那么我们得到的是一个非常快速的模型,它不涉及 Transformer 的绝大部分参数,并且无疑会在下游性能方面表现不佳。我们假设在这两个极端之间的某个地方有一个最优模型,它比普通 Transformer 更快,并且在执行步骤更快的同时性能也更好。
3.3. 路由方案
从朴素的角度来看,可以利用随机性来路由 token,类似于层或块的 “丢弃”。我们将这个路由方案作为一种控制,将显示它明显地相对于普通 Transformer 的性能不足。
我们假设学习的路由更可取。直觉上,网络应该能够学会哪些 token 需要比其他 token 更多或更少的处理。如果我们正确地认为 Transformers 经常花费比它们做出预测所需的计算要多,那么对于我们能够缩小每个块的容量的程度,以及因此围绕每个块我们能够负担得起将多少 token 路由,这是一个经验性问题。
我们考虑了两种学习路由方案(见图 2):token 选择和专家选择。
- 在 token 选择路由中,路由器产生跨计算路径的每个 token 的概率分布(例如,在 MoE Transformers 中跨专家标识)。然后, token 被分配到它们偏好的路径——即具有最高概率的路径——并且辅助损失确保所有 token 不会收敛到同一路径。token 选择路由可能存在负载平衡问题,因为不能保证 token 在可能的路径之间适当地分配。
- “专家选择路由” 将这个方法颠倒过来:不是让 token 选择它们偏好的路径,而是让每个路径根据 token 的偏好选择前 𝑘 个 token。这确保了完美的负载平衡,因为 𝑘 个 token 被保证被分配到每个路径。然而,这可能导致一些 token 过度或不足处理,因为一些 token 可能是多个路径的前 𝑘 个,也可能不是任何路径的前 𝑘 个。
我们决定利用专家选择路由的几个原因。
- 首先,它省去了辅助平衡损失的需要。
- 其次,由于前 𝑘 操作取决于路由器权重的大小,这种路由方案允许相对路由权重来帮助确定哪些 token 最需要块的计算;路由器可以通过适当设置它们的权重来确保最关键的 token 在前 𝑘 个之中,而这在 token 选择路由方案中是不可能的。 对于我们的具体用例,其中一个计算路径基本上是一个空(null)操作,重要的是将重要的 token 路由远离空操作。
- 第三,因为我们只通过两个路径路由,一个单独的前 𝑘 操作可以有效地将 token 分成两个互斥集合,每个计算路径一个,防止了上述提到的过度或不足处理问题。
3.4. 路由实现
每个 token 都经过路由器处理以产生一个标量权重,然后使用前 𝑘 个权重来选择通过 Transformer 块的 token 的标识,该块包括自注意力和 MLP。 假设我们对于给定层 𝑙 的长度为 𝑆 的序列有一组 token 嵌入,即 𝑋^𝑙 = {𝑥^𝑙_𝑖 | 𝑖是整数,1 ≤ 𝑖 ≤ 𝑆}。对于给定的 token 嵌入,路由器权重是通过线性投影产生的标量,
我们的目标是使用这些路由器权重来确定每个 token 在块中计算的输出。假设 𝑃_𝛽 (𝑅^𝑙) 是路由器权重 𝑅_𝑙 集合的 𝛽 百分位数,其中 𝛽 = 1 − 𝐶/𝑆,而 𝐶 是用户定义的每 batch 元素的容量(一个小于 𝑆 的整数,定义了将由给定函数处理的序列中的 token 数)。对于给定的 token,块的输出是:
其中,˜𝑋^𝑙 是具有路由值 𝑟^𝑙_𝑖 > 𝑃_𝛽 (𝑅^𝑙)(即,“top-k” token)的 token 集合,而 𝑓 包括自注意力和 MLP。请注意,给定 token 𝑥^(𝑙+1)_𝑖 的输出可能取决于其他 token 𝑥^𝑙_(𝑖≠𝑗),因为存在自注意力操作。˜𝑋^𝑙 的基底是 𝐶(或 𝑘):用户定义的容量。因此,与基线相比,MoD Transformer 在块计算的输入 𝑓 中包含的 token 数较少(𝐶 < 𝑆),从而使自注意力和 MLP 更加节省成本。
值得注意的是,我们将函数 𝑓 的输出乘以路由器权重。这将路由器权重置于 “梯度路径”上,因此它们在语言建模任务的过程中受到梯度下降的影响(我们尝试了一些版本,其中路由器权重也包括在绕过块计算的那些 token 的计算路径中,但仅将路由器权重包括在那些不绕过块计算的 token 的计算路径中,似乎足够简单且易于实现)。
3.5. 抽样
尽管专家选择路由具有许多优点,但它有一个明显的问题:top-𝑘 操作是非因果的。这意味着一个给定 token 的路由权重是否在序列的 top-𝑘 中取决于它后面的 token 的路由权重的值,而我们在自回归抽样时无法访问这些值。
我们测试了两种方法来解决这个问题。
- 第一种方法引入了一个简单的辅助损失,该损失在经验上影响了主要的语言建模目标约 0.2−0.3%,但允许我们从模型自回归地进行抽样。我们使用了一个二元交叉熵损失,其中路由器的输出提供了 logits,而这些 logits 的 top-𝑘 选择提供了目标(即,如果一个 token 位于 top-𝑘 中则为 1,否则为 0)。直觉上,这种损失使路由器的输出的 sigmoid 集中在 0.5 左右;被选择在 top-k 中的那些 token 被迫产生高于 0.5 的路由器输出,而那些没有被选在 top-k 中的 token 则被迫产生低于 0.5 的路由器输出。
- 第二种方法引入了一个小型的辅助 MLP 预测器(类似于第二个路由器),它接收与路由器相同的输入(带有停止梯度),但其输出是预测该 token 是否会成为序列中的 top-𝑘。该方法不影响语言建模目标,并且在经验上不会显著影响步骤速度。 有了这些新方法,我们可以通过根据路由器的输出选择将 token 路由到或绕过块,而这不依赖于未来 token 的任何信息来自动回归地进行抽样。我们提供了实证证据表明,这是一个相对简单的辅助任务,可以迅速达到 99% 的准确率。
3.6. 训练方法
所有模型使用相同的基本超参数配置(例如,余弦调度等于 1× 训练步骤,128 批次大小,2048 序列长度),除了更改层数,头数和嵌入大小以在 isoFLOP 分析期间生成不同大小的模型。
4. 结果
4.1. 训练,isoFLOP 比较
我们首先使用相对较小的 FLOPs 预算(6e18)训练模型,以确定最佳的超参数(见图 3)。总的来说,我们发现 MoD transformer 将基线 isoFLOP 曲线 “向下和向右” 拖动。也就是说,最佳的 MoD transformer 比最佳的基线具有更低的损失,并且有更多的参数。这种效果的一个幸运的结果是存在较小的 MoD 模型,虽然它们本身对于它们的超参数设置而言不是 isoFLOP 最优的,但是它们的性能与最佳的基线模型一样好甚至更好,同时训练更快。例如,一个 220M 参数的 MoD(图 3 模型#3)变种稍微优于 isoFLOP 最优的基线(也是 220M,图 3 模型 #1),但在训练过程中却快了多达 60%。重要的是,当在等效硬件上运行时,这两个模型变体的训练时间大致相同(图 3)。
我们测试了每个块或每隔一个块进行路由,使用总序列的 12.5% 到 95% 的容量。虽然每隔一个块进行路由对于获得强大的性能至关重要,但我们发现激进的容量减少效果最好(当将容量减少到总序列的 12.5% 时,对应于 87.5% 的 token 绕过块,观察到渐进性的改善,性能在此点之后开始下降)。因此,似乎只要能频繁进行完全容量的自注意力和 MLP 计算,网络就能够承受显著的容量减少。
学习的路由是至关重要的,因为使用随机路由的 MoD transformer (使用从高斯分布中采样的路由器权重进行的 top-𝑘 操作实现)的性能远远低于基线和普通 MoD transformer (图 3)。
图 4 显示了 6e18、2e19 和 1e20 总 FLOPs 的 isoFLOP 分析。对于这些较大的 FLOPs 预算, FLOPs 最优的 MoD transformer 比基线具有更多的参数的趋势继续存在。值得注意的是,存在 MoD 变体,它们在每次前向传递中的 FLOPs 数比 isoFLOP 最优基线要少得多(在图 4 中,我们描述了标准化的每个前向传递的 FLOPs ,而不是实际的 wall-clock 步长,但从我们的实验中,这两者之间存在紧密的相关性。可以产生类似的图,显示相对的 wall-clock 步长,并且具有相同的基本趋势)。
训练速度的增益有两个来源:
- 首先,因为一些 token 绕过块,MoD transformer 中每个参数的 FLOPs 小于基线。因此,对于给定的模型大小,Transformer 每个前向传递需要的 FLOPs 较少。
- 其次,由于 isoFLOP 最优的 MoD transformer 不仅更大,而且达到了比 isoFLOP 最优的基线更低的损失,因此存在比 isoFLOP 最优基线性能相似或更好的较小 MoD 变体,这些变体因为更小,所以训练更快。因此,存在 MoD transformer,其性能与 isoFLOP 最优基线一样好,并且训练更快,因为它们每个参数使用的 FLOPs 较少,而且它们使用的参数较少。
图 4 还揭示了另一个重要发现:最佳的 MoD transformer 是每个前向传递使用与 isoFLOP 最优基线相同 FLOPs 数的 transformer。这一发现使人们能够直接预测对于给定的 isoFLOP 训练预算,哪种大小的 MoD transformer 将达到最佳性能:只需调整给定 MoD 配置(即,容量和路由频率)的模型大小,以生成每个前向传递使用的 FLOPs 数与 isoFLOP 最优基线相同的模型,并且它们将具有该配置的最佳执行 MoD 变体。经验证明,增加深度比增加宽度更好地增加模型的 FLOPs。
尽管每个前向传递的 FLOPs 决定了哪个模型将是 isoFLOP 最优的,但它并不能预测最佳损失是否会优于基线(见图 3)。也就是说,最佳容量似乎是经验性可确定的。我们发现最好使用每个其他块的 12.5% 容量块。 我们注意到,与等效大小的基线模型相比,MoD transformer 在较大的尺寸上可内存节省,其中一些变体需要较少的总设备(即,较小的 TPU 拓扑)。我们没有进行深入研究,但我们预计随着模型规模的扩大,这些节省可能是选择要训练的模型变体时的重要考虑因素,并且在自回归抽样期间对 KV 缓存大小可能产生显着的正面影响。
图 5 显示了使用交错路由块训练的 MoD transformer 的路由决策。尽管激进的绕过块,但与基线相比,transformer 能够实现性能改进。我们观察到可能需要进一步研究的模式;换句话说,一些 token 似乎与 transformer 的深度一起与每个块进行交互,而其他 token 则在可能时决定绕过块。初步分析表明,更频繁与块进行交互的 token 与具有更高熵的输出预测相关,这可能对应于更难以做出的预测。
4.2. 自回归评估
我们在自回归抽样期间评估了 MoD 变体(见图 6)。每个模型在完全相同的保留数据上进行了测试,包括 256000 个序列(500M 个 token)。当从 top-𝑘 路由方法切换到基于预测器的路由方法时,我们观察到性能几乎没有下降。与训练设置一样,存在比 isoFLOP 最优基线更好的性能的 MoD 变体,同时每个前向传递需要的 FLOPs 更少。这些结果表明,MoD transformer 提供的计算节省应该可以超越训练设置。
4.3. 深度与专家的混合(MoDE)
MoD 技术可以自然地与 MoE 模型(一起组成 MoDE 模型)以及普通 transformer 集成。在图 7 中,我们呈现了与 MoE 相结合 MoD 提供的性能改进。我们尝试了两种变体:
- 在分阶段的 MoDE 中,在自注意步骤之前,将 token 路由到的块周围(绕过)或块内(通过)
- 在集成的 MoDE 中,通过在常规 MLP 专家之间集成 “空操作” 专家来实现 MoD 路由。
前者的优点在于它允许 token 跳过自注意步骤,而后者的优点在于它简化了路由机制。我们注意到,在集成 MoDE 机制中实施 MoDE 明显优于简单地减少常规 MoE 模型中专家的容量,并依靠 token 丢弃来实施剩余路由。我们认为这是因为在集成的 MoDE 机制中,token 明确地学习选择绕过专家的剩余路径,而不是在容量减少时选择专家但在实施时被丢弃。
5. 讨论
MoE transformer 经验性地表明,可以通过每个前向传递使用更少的 FLOPs 来改进 isoFLOP 最优基线性能。这意味着,对于给定的训练 FLOPs 预算,我们可以训练比其基线对应物更快且性能更好的模型。以前,要训练比 isoFLOP 最优模型更快且性能至少相同的模型,人们必须使用过剩的计算来过度训练较小的模型(值得注意的是,这种过度训练技术仍然适用于 MoD transformer,并且速度增益应该相乘)。
尽管 MoD transformer 每个前向传递需要的 FLOPs 较少,但不能不加选择地使用 FLOPs 。 相反,使用学习的路由决策(就像在 MoE transformer 中一样)来确定一个 token 应该参与自注意力和随后的 MLP(需要 FLOPs),或不参与(节省 FLOPs)是至关重要的。然后,我们可以通过使模型更大或更长时间地训练来使用任何保存的 FLOPs。我们的结果表明, FLOPs 在普通 transformer 模型中可能被低效地使用,并且可能有更有效的方式来使用它们。
学习的路由机制有时是非因果的;也就是说,未来的信息被用来确定给定 token 的路由决策。这通常对于 top-k 路由机制是正确的,因为它们省略了辅助平衡损失的需要。但是,在后期训练的自回归抽样中,使用关于未来 token 标识的信息来确定路由决策是不可能的。在这项工作中,我们表明可以在训练期间成功使用 top-k 路由方案,但在后续的自回归抽样中不需要它。简单的辅助分类器或路由器上的辅助损失就足以学习 top-𝑘 路由决策,使其能够在自回归抽样期间模拟 top-𝑘 决策,几乎没有性能下降。
直观地,一个 token 可能会学习绕过块,因为在该步骤中正在进行的预测更容易,因此不需要太多的计算。但是,这种策略无疑不是网络学到的全部。如果一个 token 在某个块中不参与自注意,则以后的 token 也将无法关注它。因此,决定 token 路由或不路由都会通过因果自注意力影响当前步骤的预测和未来的预测,网络如何平衡这些影响是由它们对整体语言建模目标的影响所指导的。
这一见解为 MoD 变体打开了一扇门,可以将查询、键和值的路由分离。例如,也许一个 token 愿意成为给定自注意计算中的查询,但不是键。人们可以想象将这个想法进一步扩展到 “长期记忆” 的领域:也许有些 token 无论何时都作为键是非常有价值的,而不管它们作为查询是否有用。学习的路由可能是一个强大的机制,用于决定这些 token 可能是什么,可能将它们引导到在未来自注意中可用的长期内存缓冲区。这种方法的一个优点是,token 在 “记忆编码” 的那一刻决定是否将来检索它们。这比在将来的每一步中对整个内存缓冲区执行完整的基于内容的查找更具计算效率,并且可能是朝着大幅增加用于做出预测的上下文长度的一步。
与在相同计算(通常是 MLP)之间有效地路由的 MoE transformer 不同,MoD transformer 展示了在不同类型的计算之间路由的价值。在这项工作中,这些类型是传统的 transformer 块或空计算(在功能上等同于乘以零)。但是,人们可以想象将这个想法进一步扩展到更多类型的计算之间。例如,也许一些 token 被路由到 “内存查找” 功能,而其他 token 被路由到 “工具使用” 功能。总的来说,我们部署的路由机制为调整网络可用的计算类型和它们的相对成本(总 FLOPs )提供了一个旋钮;如果想引入昂贵的计算,那么可以通过将其容量设置为某个小量来抵消,因此,只将少量 token 路由到它。
总而言之,MoD transformer 是调整模型每个前向传递的计算(因此推理时间)的另一种工具。用于实现 MoD 的机制也是通用的,并且为许多扩展和与其他技术的集成打开了大门,例如 MoE。