论文地址:https://arxiv.org/pdf/2501.00663v1
本文介绍了一篇由 Google Research 发表的关于新型神经网络架构 Titans 的论文,该架构旨在解决传统 Transformer 在处理长序列时的局限性。以下是对论文的详细解读,并结合原文图片进行说明:
1. 背景与动机
1.1 Transformer 的局限性
- 计算复杂度高: Transformer 的注意力机制基于 softmax 计算,其时间复杂度和空间复杂度与序列长度呈二次方关系,导致在处理长序列时计算成本高昂。
- 原文: “尽管 Transformer 在序列建模中表现出色,但其主要构建模块——注意力模块——作为联想记忆块运作,计算成对查询(搜索信号)和键(上下文)之间的相似度。这种对依赖关系的准确建模带来了二次方的时间和内存复杂度,限制了模型只能处理固定长度的上下文。”
- 上下文窗口有限: 由于上述计算瓶颈,Transformer 难以处理需要长距离依赖关系的复杂任务,例如语言建模、视频理解和长时序预测。
1.2 线性 Transformer 的不足
- 性能下降: 为了解决计算复杂度问题,研究人员提出了线性 Transformer,用核函数替代 softmax,显著降低了内存消耗,但性能不如标准 Transformer。
- 原文: “尽管线性 Transformer 提高了效率并能够扩展到更长的上下文,但与 Transformer 相比,其性能并不具有竞争力,因为核技巧使模型成为线性循环网络,数据被压缩为矩阵值状态。”
- 信息压缩问题: 线性 Transformer 将历史数据压缩到固定大小的矩阵中,难以有效处理长距离依赖关系。
1.3 人类记忆的启发
- 多层次记忆系统: 人类记忆系统由多个功能不同的子系统组成,例如短期记忆、工作记忆和长期记忆。
- 原文: “事实上,记忆是一个系统的联合体——例如,短期、短期和长期记忆——每个系统都有不同的功能,不同的神经结构,并且每个系统都能够独立运作。”
- 记忆与学习的关系: 记忆是学习的基础,良好的记忆系统对于有效学习至关重要。
2. Titans 的核心思想
基于以上背景,Titans 提出了一种新的神经网络架构,旨在结合短期记忆和长期记忆的优势,以更有效地处理长序列数据。
2.1 神经长时记忆模块
- 设计理念:
- 模拟人类长期记忆系统,能够在测试时学习如何记忆/存储数据。
- 引入“惊喜”度量指标,衡量输入数据与历史数据的差异性,惊喜值越大的数据越容易被记忆。
- 原文: “我们设计这个记忆模块,以便违反期望的事件(令人惊讶的)更令人难忘。为此,我们使用神经网络相对于输入的梯度作为联想记忆损失中的惊喜度量。”
- 采用衰减机制,根据记忆容量和数据惊喜量来管理记忆,避免内存溢出。
- 原文: “为了更好地处理有限的内存,我们提出了一个衰减机制,它考虑了内存大小和数据惊喜量的比例,从而实现更好的内存管理。”
- 训练过程:
- 将训练视为在线学习问题,将过去的信息压缩到长时记忆模块的参数中。
- 使用带有动量和权重衰减的梯度下降法进行优化。
- 通过张量化 mini-batch 梯度下降,并使用更多矩阵乘法操作,实现了快速并行化训练。
- 原文: “我们展示了这种衰减机制实际上是现代循环模型中遗忘机制的泛化。我们发现该机制等同于使用小批量梯度下降、动量和权重衰减优化元神经网络。在此基础上,我们提出了一个快速且可并行化的算法来训练我们的深度神经长时记忆。”
2.2 Titans 架构
Titans 架构由三个主要模块组成:
(1) 核心模块 (Core):
- 负责处理数据的主要流程,使用有限窗口大小的注意力机制。
- 相当于短期记忆模块,关注当前上下文窗口。
(2) 长时记忆模块 (Long-term Memory):
- 我们的神经长时记忆模块,负责存储/记住长距离的历史信息。
(3) 持久记忆模块 (Persistent Memory):
- 一组可学习的、但与数据无关的参数,用于编码有关任务的元知识。
- 相当于任务相关的记忆,存储任务相关的知识。
图 1: 神经记忆的训练过程可以并行化,并使用矩阵乘法操作。
3. 如何将记忆融入架构?
Titans 提出了三种不同的变体来将记忆模块融入到深度学习架构中:
3.1 记忆作为上下文 (Memory as a Context, MAC)
-
架构设计:
- 将长序列分割成固定大小的片段,将当前片段作为当前上下文,其前一个片段作为历史信息。
- 使用当前上下文作为查询,从长时记忆模块中检索相应的信息。
- 将检索到的历史信息与持久记忆参数一起作为输入序列输入到注意力模块中。
- 原文: “接下来,我们使用这个历史信息以及我们的持久记忆参数作为注意力模块的输入序列。”
- 注意力模块决定哪些信息应该存储在长时记忆中。
- 架构图示:
- 图 2: 记忆作为上下文 (MAC) 架构。该架构包括三个分支:(1) 核心,(2) 上下文(长时)记忆,和 (3) 持久记忆。核心分支将相应的长时和持久记忆与输入序列连接起来。接下来,注意力在序列上执行,并决定哪些信息应该存储在长时记忆中。在测试时,对应于上下文记忆的参数仍在学习,对应于核心分支的参数负责上下文内学习,而持久记忆的参数负责存储有关任务的知识,因此是固定的。
-
优点:
- 注意力模块能够决定是否需要长时记忆信息。
- 注意力模块帮助长时记忆只存储当前上下文中有用的信息,避免内存溢出。
- 持久记忆参数固定,注意力模块权重进行上下文内学习,长时记忆模块在测试时仍在学习。
3.2 门控记忆 (Gated Memory, MAG)
-
架构设计:
- 一个分支直接使用输入数据更新长时记忆,另一个分支使用滑动窗口注意力 (SWA)。
- 原文: “在下一个变体中,在一个分支中,我们直接使用输入数据来更新长时记忆,在第二个分支中,我们使用滑动窗口注意力 (SWA)。”
- 使用门控机制将记忆与核心分支结合起来。
- 架构图示:
- 图 3: 不同变体的 Titans 的注意力掩码。
- (b) 记忆作为门控 (MAG)。我们使用滑动窗口注意力 (SWA) 作为短期记忆,我们的神经记忆模块作为长期记忆,通过门控机制进行组合。
- 一个分支直接使用输入数据更新长时记忆,另一个分支使用滑动窗口注意力 (SWA)。
-
特点:
- SWA 作为精确的短期记忆,神经记忆模块作为模型的渐逝记忆。
- 可以看作是多头架构,其中头的结构不同。
3.3 记忆作为层 (Memory as a Layer, MAL)
-
架构设计:
- 将神经记忆作为深度神经网络的一层。
- 原文: “最后一个变体将神经记忆作为深度神经网络的一层 (MAL)。”
- 类似于混合模型,将循环模型与全连接或滑动窗口注意力堆叠在一起。
- 架构图示:
- 图 4: 记忆作为层 (MAL) 架构。在该架构中,记忆层负责在注意力模块之前压缩过去和当前上下文。
- 将神经记忆作为深度神经网络的一层。
-
缺点:
- 模型的性能受限于每一层的性能,无法充分利用注意力与神经记忆模块的互补数据处理能力。
4. 实验结果
4.1 语言建模与常识推理
- Titans 在语言建模和常识推理任务中表现优于所有基线模型,包括 Transformer 和现代线性循环模型。
- 混合模型 (MAC、MAG、MAL) 性能优于 Samba (Mamba + 注意力) 和 Gated DeltaNet-H2 (Gated DeltaNet + 注意力)。
- MAC 在处理数据中的长距离依赖关系时表现更好。
4.2 针刺任务 (Needle in a Haystack)
- Titans 在针刺任务中表现出色,能够有效处理超长序列,检索信息的能力优于基线模型。
- 原文: “我们认为,Titans 的这种优异性能归功于 Titans 与现有序列模型之间的三个关键区别:…”
- 神经记忆模块能够更好地管理内存容量,并通过动量和遗忘机制 (权重衰减) 保持性能稳定。
4.3 BABILong 基准测试
- 在 BABILong 基准测试中,Titans 再次展现出强大的长距离推理能力,在少样本和微调设置下均优于所有基线模型,包括 GPT-4。
4.4 时间序列预测
- 神经记忆模块在时间序列预测任务中也表现出色,优于基于 Mamba、线性模型和 Transformer 的架构。
4.5 DNA 建模
- 神经记忆模块在基因组学下游任务中与最先进架构相比具有竞争力。
4.6 效率
- 神经记忆模块的训练吞吐量略低于 Mamba2 和 Gated DeltaNet,主要原因是其具有更深的记忆和更复杂的转换过程。
- Titans (MAL) 的训练速度比基线模型和记忆模块更快,主要得益于 FlashAttention 的高度优化内核。
4.7 消融研究
- 神经记忆模块的所有组件均对性能有积极贡献,其中权重衰减、动量、卷积和持久记忆贡献最大。
5. 结论
Titans 架构通过引入神经长时记忆模块,弥补了 Transformer 在处理长序列时的不足。其主要优势在于:
- 更强的长距离依赖建模能力: 神经记忆模块能够有效存储长距离信息,并进行在线学习。
- 更好的内存管理: 动量和遗忘机制确保了内存的有效利用。
- 可扩展性: Titans 可以扩展到超过 2M 的上下文窗口大小,同时保持较高的准确性。
总的来说,Titans 为长序列建模提供了一种新的思路,在多个任务上展现出强大的性能。