摘要
记忆镶嵌是由多个关联记忆网络协同工作来完成感兴趣的预测任务。与transformer类似,记忆镶嵌具有组合能力和上下文学习能力。与transformer不同,记忆镶嵌以相对透明的方式实现这些能力。该研究在玩具示例上展示了这些能力,并且还表明记忆镶嵌在中等规模语言建模任务上的表现与transformer一样好或更好。
1 简介
本文介绍了一种学习系统架构,记忆镶嵌,其中多个关联记忆协同工作来执行感兴趣的预测任务。这种系统与记忆网络密切相关,尽管有显著差异,但与transformer相似。与transformer一样,记忆镶嵌具有机器学习系统长期难以实现的一些分解和组合能力。与transformer的内部机制难以破译不同,记忆镶嵌以相对透明的方式实现这些能力。
本工作的三个主要贡献是:(a)认识并利用平滑关联记忆和self-attention之间的相似性,(b)识别并说明预测解耦原则,解释训练如何以有趣的方式分解整体任务,以及©表明与解码transformer相比,这种相对透明的架构在语言建模任务上的性能相当。
第2节描述了基本架构并概述了其结果。第3节说明了预测解耦原则。第4节将这些思想扩展到完全形成的记忆镶嵌。第5节报告了中等规模语言建模实验。
2 记忆
关联记忆 一般来说,关联记忆是一种可以存储键值对并在给定相应键时检索值的设备。这个定义省略了处理重复键和近似匹配的重要细节。出于该研究的目的,键和值都应该是 R d \mathbb{R}^d Rd中的向量。然后,检索过程可以表示为查询键 k k k和所有存储对 ( k 1 , v 1 ) . . . ( k n , v n ) (k_1,v_1)...(k_n,v_n) (k1,v1)...(kn,vn)的函数。
R
d
→
R
d
\mathbb{R}^d \rightarrow \mathbb{R}^d
Rd→Rd
k
↦
f
(
k
;
{
(
k
1
,
v
1
)
.
.
.
(
k
n
,
v
n
)
}
)
k \mapsto f(k; \{(k_1, v_1)...(k_n, v_n)\})
k↦f(k;{(k1,v1)...(kn,vn)})
除了涉及重复键的情况外,关联记忆存储键值对时不考虑它们的时间顺序。因此,可以假设检索函数对存储对的任何排列是不变的。这种可交换性表明,该研究也可以将关联记忆视为一个设备,它根据键值对样本 ( k 1 , v 1 ) . . . ( k n , v n ) (k_1,v_1)...(k_n,v_n) (k1,v1)...(kn,vn)估计条件概率分布 P ( V ∣ K ) P(V|K) P(V∣K)。然后,检索函数是这个估计分布的条件期望:
f ( k ; { ( k 1 , v 1 ) . . . ( k n , v n ) } ) = E ( V ∣ K = k ) f(k; \{(k_1, v_1)...(k_n, v_n)\}) = \mathbb{E}(V | K = k) f(k;{(k1,v1)...(kn,vn)})=E(V∣K=k). (1)
这样的条件期望可以用高斯核平滑构造,
f ( k ; { ( k 1 , v 1 ) . . . ( k n , v n ) } ) = ∑ i = 1 n 1 Z e − β ∣ ∣ k − k i ∣ ∣ 2 v i w i t h Z = ∑ i = 1 n e − β ∣ ∣ k − k i ∣ ∣ 2 f(k; \{(k_1, v_1)...(k_n, v_n)\}) = \sum_{i=1}^n \frac{1}{Z} e^{-\beta||k-k_i||^2} v_i \quad with \quad Z = \sum_{i=1}^n e^{-\beta||k-k_i||^2} f(k;{(k1,v1)...(kn,vn)})=∑i=1nZ1e−β∣∣k−ki∣∣2viwithZ=∑i=1ne−β∣∣k−ki∣∣2. (2)
当所有键向量 k i k_i ki共享相同的平方范数时,核平滑和注意力之间的紧密联系[Bahdanau et al., 2015]特别明显,因为表达式(2)可以简化为
f ( k ; { ( k 1 , v 1 ) . . . ( k n , v n ) } ) = ∑ i = 1 n e β k ⊤ k i ∑ j = 1 n e β k ⊤ k j v i f(k; \{(k_1, v_1)...(k_n, v_n)\}) = \sum_{i=1}^n \frac{e^{\beta k^{\top} k_i}}{\sum_{j=1}^n e^{\beta k^{\top} k_j}} v_i f(k;{(k1,v1)...(kn,vn)})=∑i=1n∑j=1neβk⊤kjeβk⊤kivi. (3)
当然,还有更有利的方法来实现这样的关联记忆,从快速高斯变换到局部敏感哈希。尽管这些方法在未来肯定会被证明是有用的,但是该研究只依赖于用高斯核平滑实现的关联记忆,不仅是因为这使得梯度的计算变得容易。
用关联记忆预测 现在考虑一个观测序列 ( x t ) (x_t) (xt),离散token或连续值。该研究想利用过去的观测值 ( x t ) t ≤ T (x_t)_{t \leq T} (xt)t≤T来预测未来观测值 ( x t ) t > T (x_t)_{t > T} (xt)t>T的某些有用属性。例如,该研究可能想预测下一个观测值 x T + 1 x_{T+1} xT+1来构建序列的自回归模型。
该研究的基本记忆单元(图1)包括一个关联记忆和一个可训练的特征提取器,为记忆计算合适的键和值。键 k T k_T kT根据过去的观测值 ( x t ) t ≤ T (x_t)_{t \leq T} (xt)t≤T和可训练的权重 w w w计算,
k T = ϕ ( x T , x T − 1 , . . . ; w ) k_T = \phi(x_T, x_{T-1}, ...; w) kT=ϕ(xT,xT−1,...;w). (4)
相比之下,值 v T v_T vT被允许窥视未来,因为它们代表记忆模块旨在预测的内容。例如,该研究描述的系统仅允许值依赖于下一个观测值 x T + 1 x_{T+1} xT+1,
v T = ψ ( x T + 1 , x T , x T − 1 , . . . ; w ) v_T = \psi(x_{T+1}, x_T, x_{T-1}, ...; w) vT=ψ(xT+1,xT,xT−1,...;w). (5)
在任何给定时间 T T T,关联记忆包含先前观察到的对 ( k t , v t ) t ≤ T − 1 (k_t,v_t)_{t \leq T-1} (kt,vt)t≤T−1,并通过查询新计算的键 k T k_T kT来检索尚未知道的值 v T v_T vT的估计 y T y_T yT。一个时间步长之后,观测值 x T + 1 x_{T+1} xT+1变得可用,值 v T v_T vT可以计算,对 ( k T , v T ) (k_T,v_T) (kT,vT)被存储到记忆中。
特征提取函数的确切形式可能在复杂性上有相当大的变化。例如,当每个观测值 x T x_T xT携带足够的信息时,键 k T k_T kT和值 v T v_T vT可以分别计算为 x T x_T xT和 x T + 1 x_{T+1} xT+1的线性函数
k
T
=
W
ϕ
x
T
k_T = W_\phi x_T
kT=WϕxT
v
T
=
W
ψ
x
T
+
1
v_T = W_\psi x_{T+1}
vT=WψxT+1. (6)
然而,考虑更复杂的特征提取函数通常很有吸引力,包括卷积、leaky average、layers、转移函数、归一化,甚至多项式时间序列汇总。例如,第5节的语言实验使用以下形式的特征提取器,其中 Norm ( x ) = x / ∣ ∣ x ∣ ∣ \operatorname{Norm}(x) = x/||x|| Norm(x)=x/∣∣x∣∣:
k T = α ϕ Norm ( k ˉ T ) w i t h k ˉ T = k ~ T + λ ϕ k ˉ T − 1 , k ~ T = W ϕ x T ⏞ 在 t = T, T-1, ..., 1 上的 leaky average v T = α ψ Norm ( v ˉ T ) w i t h v ˉ T = v ~ T + λ ψ v ~ T + 1 , v ~ T = W ψ x T ⏟ 在 t=T 和 T+1 上的卷积 \begin{aligned} k_T &= \alpha_\phi \operatorname{Norm}(\bar{k}_T) \quad with \quad \overbrace{\bar{k}_T = \tilde{k}_T + \lambda_\phi \bar{k}_{T-1}, \quad\tilde{k}_T = W_\phi x_T}^{\text{在 t = T, T-1, ..., 1 上的 leaky average}}\\ \\ v_T &= \alpha_\psi \operatorname{Norm}(\bar{v}_T) \quad with \quad \underbrace{\bar{v}_T = \tilde{v}_T + \lambda_\psi \tilde{v}_{T+1}, \quad\tilde{v}_T = W_\psi x_T }_{\text{在 t=T 和 T+1 上的卷积}}\\ \quad {\quad} \end{aligned} kTvT=αϕNorm(kˉT)withkˉT=k~T+λϕkˉT−1,k~T=WϕxT 在 t = T, T-1, ..., 1 上的 leaky average=αψNorm(vˉT)with在 t=T 和 T+1 上的卷积 vˉT=v~T+λψv~T+1,v~T=WψxT (7)
训练记忆单元网络 现在考虑一个深度网络,其中包含胶水层和基本记忆单元,在胶水层和记忆单元的特征提取器中都有可训练的权重。总之,这些权重决定每个记忆单元记忆什么(通过键值提取函数)以及如何组合它们的输出来执行整体预测任务。
当使用可微分的核平滑机制实现关联记忆时,训练这样一个深度网络只是一个在时间上展开网络并反向传播梯度的问题,这种方式对于现代深度学习软件的用户来说非常熟悉。不足为奇的是,沿输入序列 ( x 1 . . . x D ) (x_1...x_D) (x1...xD)展开等式(3),得到的表达式与 masked self-attention 非常相似。
∀ T ∈ { 1... D } y T = ∑ i = 1 T − δ e β k T ⊤ k i ∑ j = 1 T − δ e β k T ⊤ k j v i \forall T \in \{1...D\} \quad y_T = \sum_{i=1}^{T-\delta} \frac{e^{\beta k_T^{\top} k_i}}{\sum_{j=1}^{T-\delta} e^{\beta k_T^{\top} k_j}} v_i ∀T∈{1...D}yT=i=1∑T−δ∑j=1T−δeβkT⊤kjeβkT⊤kivi (8)
然而,在(8)和经典 self-attention 之间有显著差异:
- 更激进的 masking 涉及计算 v T v_T vT所需的未来时间步数 δ ≥ 1 \delta \geq 1 δ≥1。
- 不需要位置编码或区分键向量和查询向量。
- 每个记忆单元的最终目的是可观察的,因为在每个时间 T T T,其输出 y T y_T yT可以解释为条件期望(1),它根据过去的观测值预测尚未知道但明确的量 v T = ψ ( ( x t ) t ≤ T + δ ; w ) v_T = \psi((x_t)_{t \leq T+\delta}; w) vT=ψ((xt)t≤T+δ;w)。
学习和元学习 这个训练过程实际上应该被视为一个元学习过程,与推理时发生的学习不同,推理时新的键值对被存储到记忆中。在完整的输入序列 ( x 1 . . . x D ) (x_1...x_D) (x1...xD)上对网络进行时间展开,揭示了记忆的运行情况。由于总体预测成本在训练序列的每个时间步长评估项的平均值,训练过程有利于统计效率的记忆,能够在存储尽可能少的键值对后输出有用的值估计 y T y_T yT(图2)。
-
首先假设每个记忆单元有一个固定的值提取函数 ψ \psi ψ。训练过程仍然可以通过调整键提取函数 ϕ \phi ϕ的参数,即学习如何将当前预测上下文 ( x T , x T − 1 , x T − 2 . . . ) (x_T,x_{T-1},x_{T-2}...) (xT,xT−1,xT−2...)与过去的预测上下文 ( x t , x t − 1 , x t − 2 . . . ) (x_t,x_{t-1},x_{t-2}...) (xt,xt−1,xt−2...)进行比较,其中 t < T t<T t<T,从而使每个记忆单元在统计上更有效率。学习一个相似度度量(一个核)是一种众所周知的使非参数估计量更有效的方法。例如,训练过程可以构造键来总结相关的上下文信息,丢弃可能增加与相似值相关的键之间距离的噪声因素。它还可以调整有效的核带宽,例如,使用等式(7)中的参数 α ϕ \alpha_\phi αϕ。
-
当有多个记忆单元可用时,训练过程还可以优化值提取函数 ψ \psi ψ,在可用的记忆单元之间重新分配整体预测工作。例如,训练过程可以构造值 v T v_T vT,使它们各自的记忆单元更有效地近似,但其近似仍然可以组合以解决整体预测任务。
该研究在这个工作中论证,这种重新分配不仅使记忆单元更有效率,而且还将原始预测问题解耦为基本子问题,这些子问题后来可以以无数方式重新组合。
预测解耦 该研究已经论证,训练时间展开的记忆单元网络实际上是一个元训练过程,它学习如何有效地使用记忆单元来解决感兴趣的任务。当总体预测任务可以分解为较小的预测任务时,使用多个记忆单元而不是单个记忆单元是有益的,这些较小预测任务的目标由值 v T v_T vT表示,并且可以更有效地单独预测而不是一起预测。然后可以重新组合解耦的记忆,为全局上与训练输入非常不同但其解耦组件可以单独预测的输入提供预测。第3节提供了一个玩具示例。
解耦一直被认为是既理想又难以确定的。统计定义在改变数据分布时缺乏稳健性。没有主动实验,因果定义无法测试。预测解耦提供了一个吸引人的替代方案,它不需要特定的训练算法,而是作为元训练过程的副作用出现。尽管在关联记忆单元网络的情况下,预测解耦更容易理解,但该研究可以安全地推测,类似的现象也发生在transformer中,尽管方式远没有那么明确。
3 跟踪三个月亮
三个月亮绕着一个遥远的行星运行。尽管当地的天文学家离理解天体力学还很远,2但他们仍然观察到周期性运动,并争论如何预测未来月亮位置。一位天文学家建议编制一张包含所有三个月亮每日位置的表格,认为如果当前的月亮位置与之前的观测相匹配,未来的月亮位置将与随后的观测相匹配。另一位天文学家则建议制作三张表格,每个月亮一张,认为每个月亮的未来位置可以通过将其当前位置与先前观察到的位置进行匹配来独立预测。
为了做出可靠的预测,第一位天文学家需要一个表格,其中至少包含每个可能的月亮配置的一条记录。因此,我们的天文学家需要记录每日月亮位置,直到所有三个月亮返回到它们的原始配置,这需要的天数等于单个月亮周期的最小公倍数 l c m ( p 1 , p 2 , p 3 ) lcm(p_1,p_2,p_3) lcm(p1,p2,p3)。相比之下,第二位天文学家只需要记录每日月亮位置,直到每个月亮返回到先前观察到的位置,这需要的天数等于最慢月亮的周期 m a x ( p 1 , p 2 , p 3 ) max(p_1,p_2,p_3) max(p1,p2,p3)。
有人可能会争辩说,第二位天文学家的建议显然更优越,因为三个月亮是不同的物体,在空间和时间上都是分开的。也可以说,我们之所以将月亮视为独立的物体,正是因为它们各自的未来通常可以独立预测。只要月亮不碰撞,空间和时间分离仅仅表明独立预测的可能性。
模型 出于该研究的目的,每个观测值 x t x_t xt由三个复数 e i θ k e^{i\theta_k} eiθk组成,编码三个月亮在各自轨道平面内的角度位置 θ k \theta_k θk。我们考虑两个单层模型(图3),其中有 N h = 1 N_h=1 Nh=1或 N h = 3 N_h=3 Nh=3个记忆单元,其添加的维度与输入维度匹配。键和值提取函数遵循(6),可训练参数收集在两个 3 × 3 3 \times 3 3×3复数矩阵 W ϕ W_\phi Wϕ和 W ψ W_\psi Wψ中。记忆单元遵循(3),固定参数 β = 50 \beta=50 β=50。第三个 3 × 3 3 \times 3 3×3复数矩阵 W z W_z Wz将记忆单元预测组合成一个输出 z T z_T zT,希望预测 x T + 1 x_{T+1} xT+1。两个网络共享一个有趣的解析解:将所有三个矩阵 W ϕ W_\phi Wϕ、 W ψ W_\psi Wψ和 W z W_z Wz设置为单位矩阵,一旦关联记忆看到足够的样本,就会产生最优预测。
训练 使用长度为800的随机生成序列 ( x t ) (x_t) (xt)训练网络。每个序列包含三个月亮,它们的周期由随机选择的比率相关,并进行缩放以确保800个观测序列包含至少三个完整的月亮系统周期 l c m ( p 1 , p 2 , p 3 ) lcm(p1,p2,p3) lcm(p1,p2,p3)。验证序列的构造类似,使用一组不出现在训练集中的月亮周期。
图4和图5显示了两个网络的预测误差作为上下文长度(即存储在记忆中的观测值数量)的函数。更准确地说,对于每个序列 ( x t ) (x_t) (xt)和每个时间索引 T T T,我们计算接下来25个真实月亮位置 x T + 1 . . . x T + 25 x_{T+1}...x_{T+25} xT+1...xT+25和接下来25个自回归预测之间的平均绝对偏差(其中连续预测被循环反馈到网络输入中)。图中显示了在共享相同月亮周期集的512个序列上平均的曲线,这些序列来自训练集或验证集。
-
对于单头网络(图4),图中显示在 l c m ( p 1 , p 2 , p 3 ) lcm(p_1,p_2,p_3) lcm(p1,p2,p3)次观测后出现急剧转变(红色垂直线),此时网络从通过重复上一次观测来预测未来月亮位置,转变为通过寻找匹配的记忆月亮配置来预测。
-
对于三头网络(图5),每当上下文长度足以完全描述每颗行星的轨道时,预测误差曲线就会下降。网络在观察到 m a x ( p 1 , p 2 , p 3 ) max(p_1,p_2,p_3) max(p1,p2,p3)(最后一条黑色垂直线)后就产生准确的预测,即在看到完整的月亮配置集合(红线)之前。这是可能的,因为这些未见过的月亮配置是通过组合单个月亮预测预测的。
图6描绘了学习到的权重矩阵,显示网络如何成功地解耦三个月亮( W ψ W_\psi Wψ和 W ϕ W_\phi Wϕ)并学习如何重新组合单个预测( W z W_z Wz)。
训练注释 训练三头网络可能相当具有挑战性,这让人想起早期的XOR网络。我们使用两个技巧获得了可靠的收敛。首先,我们使用 3 × 3 3 \times 3 3×3复数矩阵(18个实参数)而不是对3维复向量作为6维实向量进行操作的 6 × 6 6 \times 6 6×6实矩阵(36个实参数)来稍微限制线性运算。其次,为了防止训练算法在记忆几乎为空时尝试优化预测误差,我们对均方损失进行了裁剪。
也可以通过使 W ϕ W_\phi Wϕ、 W ψ W_\psi Wψ或 W z W_z Wz等于单位矩阵来可靠地实现收敛。这样做当然会使网络偏向解耦解,我们想避免这种情况。然而,相信解耦通常可以在规范基础上实现是不无道理的。例如,在空间上分离良好的对象通常出现在图像的不同区域,因此沿着不同的像素轴。
4 分层记忆
为了激发更复杂的记忆单元网络,让我们通过添加一个太阳和几个行星及其各自的卫星来挑战我们的天文学家。每隔一段时间,也许只是在他们的想象中,我们的天文学家被传送到一个新的行星系统,观察新的天空,必须整理一组新的天体。
持久记忆 各种提示可以为新的行星系统建议一个结构:太阳具有独特的外观;内行星保持靠近太阳;远行星的卫星从不远离它们的中心等。我们的天文学家还需要几何知识来确定要制表的内容以及如何组合制表的量并预测未来的天空。我们称这种知识为持久的,因为它适用于所有行星系统,与特定行星系统先前观察到的天空等上下文知识相反。
在图3的三月亮架构中,持久知识存储在每个记忆单元层的特征提取器的参数 W ϕ W_\phi Wϕ、 W ψ W_\psi Wψ和混合层的参数 W z W_z Wz中。这些参数在训练时通过梯度反向传播和优化来确定。这种简单的线性函数可能不再足以编码所需的持久知识。特征提取函数和组合函数可能必须执行非线性计算,例如,反映实际行星位置到天球上的投影。
遵循Sukhbaatar的方法,持久知识可以存储在持久关联记忆中,其中包含预定数量的键值对 ( k i , v i ) i = 1... N m (k_i,v_i)_{i=1...N_m} (ki,vi)i=1...Nm,其值在训练时通过梯度反向传播确定,并在推理时保持固定。持久记忆单元(图7)因此与上下文记忆单元(图1)非常相似,但依赖于持久关联记忆。持久记忆单元不再需要显式值提取函数,因为在推理时不更新记忆内容。正如Sukhbaatar所指出的,它们也可以被视为具有单个隐藏层的全连接神经网络,该隐藏层使用softmax非线性而不是分量转移函数。交错上下文和持久记忆单元层然后可以理解为增加上下文记忆的特征提取器或组合层的有效复杂性的手段(剧透见图8)。然而,我们还发现在概念上将持久记忆输出 y t y_t yt视为隐式值函数的条件期望很有用,该函数没有显式参数化,但可以在训练后确定。
路由 我们的天文学家可以设计什么样的解决方案来预测新行星系统的天空?他们也许首先制表太阳位置,然后是每个行星相对于太阳的位置,然后是每个卫星相对于其行星的位置。为了实现这样的解决方案,记忆单元网络不仅需要计算所有这些相对位置,而且还必须在表示每个这些表的记忆单元之间路由信息。此外,由于同一网络必须处理与不同组成的不同行星系统相关的观测序列,这些路由必须根据每个新观测序列的第一次观测动态改变。
尽管这种动态路由问题似乎难以克服,但我们首先可以观察记忆单元网络如何实现其行为在训练时确定的静态路由电路。静态路由电路使用每个记忆单元层的特征提取器和组合层来实现(图3),两者都通过交错的持久记忆单元层增强了非线性操作。就像专家混合的门控模块一样,这样的路由电路可以实现依赖于数据的路由。然而,我们仍然称这种路由为静态的,因为所有可能的路由都是在训练期间预先确定的。
为了获得动态路由能力,训练算法可以招募上下文记忆单元来替换或补充参与路由电路的持久记忆单元。因为上下文记忆单元的内容在推理时更新,将其中一些招募到路由电路中提供了根据新序列的第一次观测创建新路由的手段,这表明了胶囊网络的替代方案。
记忆镶嵌 在这样一个复杂的网络中,上下文记忆单元之间的分工仍然由预测解耦原则决定。在训练过程中,图2中的蒸汽压路机将上下文记忆单元推向更容易独立记忆而不是聚合记忆的函数。这不仅适用于记录主要信息片段(如第3节中的月亮位置)的记忆单元,而且也适用于影响路由电路和对早期记忆单元产生的信息进行操作的记忆单元。
因此,在预测解耦原理的压力下,记忆单元网络不仅记忆解耦的信息片段,而且还记忆它们如何组合在一起,以及它们的组合如何再次分解为新的解耦片段并以无数方式重新组合。这就是为什么我们称这种网络为记忆镶嵌。
5 从行星系统到语言建模
到目前为止,我们已经将记忆镶嵌描述为一种在重要方面类似transformer的架构,其基本属性由于预测解耦原则而相对透明。为了使这个愿景更有说服力,我们提供证据表明记忆镶嵌可以处理transformer最成功的应用,即语言建模。
让我们首先论证预测解耦是语言建模的一个有意义的概念。在对话的情况下,预测接下来谁会说话在很大程度上可以独立于所说的内容来预测。同样适用于预测下一段对话是用英语还是德语表达,参与者是否一致,或者他们的语法是正式还是非正式。解耦这些预测提供了利用更多训练数据的机会。例如,可以预测特定的对话参与者会给出紧张的答案,因为几乎任何置身于相同情况的对话参与者都会给出紧张的答案。
语言建模任务 Eldan和Li的TinyStories数据集提供了使用相对较小的语言模型来研究通常只出现在更大模型中的语言建模现象的手段。这是通过将文本限制在三岁孩子可以理解的简单语言中并描述发生在三岁孩子可以理解的简单世界中的故事来实现的。借助这个有限的范围,Eldan和Li展示了一个在tiny stories语料库上训练的小型语言模型(33M参数)生成的延续在语言质量和叙事一致性方面远远优于在通用web语料库上训练的更大模型(1.5B参数)。
在Eldan和Li的带领以及法律部门的告诫下,该研究使用Mixtral-8x7B开放语言模型生成了一个新的tiny stories语料库,称为BabiStories。该语料库及其生成在附录A中详细描述。
架构 为了将实验置于上下文中,设计了一个与经典GPT2-small transformer架构密切匹配的记忆镶嵌架构。如图8所示,两个架构并排使用相同的GPT2分词器、相同的嵌入维度( d = 768 d=768 d=768)和相同数量的头( N h = N c = N p = 12 N_h=N_c=N_p=12 Nh=Nc=Np=12)。两种架构都使用长度为512的序列进行训练和测试,即一到三个故事长。
这两种架构之间有三个主要区别。首先,记忆镶嵌不使用位置编码。其次,与每个transformer块的 N h = 12 N_h=12 Nh=12个注意力头不同,每个块中的 N c = 12 N_c=12 Nc=12个上下文记忆单元不区分键和查询(图1),而是使用等式(7)中描述的键和值提取函数。键用过去输入的leaky average形成,值可以提前一个时间步长查看。因此,注意力掩码排除主对角线以避免破坏因果性。最后,经典transformers块的前馈网络(FFN)被 N p = 12 N_p=12 Np=12个持久记忆单元的层替换,完整地包含键提取函数(7)和组合层,其大小确保记忆镶嵌架构的每块参数数量与GPT2-small密切匹配。
训练和验证 图9显示了在BabiStories上训练的不同深度的transformers和记忆镶嵌的训练和验证曲线。对于小深度网络,记忆镶嵌略微优于transformer,5,但是当深度增加时,这种效果消失,训练和验证损失变得无法区分。
重要的是,所有超参数都是针对transformer架构调整的(附录B),并逐字传输到记忆镶嵌中。这种选择可能解释了训练曲线如此紧密地跟踪对方。它还让记忆镶嵌在这个比较中处于轻微劣势。
定性评估 为了比较在tiny stories上训练的模型生成文本的质量,Eldan和Li设计了二十四个提示,用于测试生成延续的事实、逻辑和一致性属性。表1比较了由深度为 N b = 18 N_b=18 Nb=18的transformer和记忆镶嵌在这些提示上生成的延续。两个模型在这项任务上的表现非常相似。
分布外评估 Simple English Wikipedia是维基百科的一个版本,用更容易理解的语言编写。尽管有预期的简单性,这些文章实质上比我们的BabiStories更长、更复杂。因此,使用在BabiStories上训练的模型预测Simple English Wikipedia文章是一项具有挑战性的分布外任务。
图10显示了每个token的平均损失作为生成token在输入窗口中位置的函数。transformer和记忆镶嵌都是 N b = 512 N_b=512 Nb=512块深。在这个实验中,当增加的上下文大小揭示分布是不同的时,预期token预测会改善。transformer性能在100到150个token后达到平台期,这比典型的tiny story稍短。在大约50个token之后,记忆镶嵌大大优于transformers,表明其具有更强的上下文学习能力。
上下文学习评估 为了严格比较各种架构的上下文学习能力,RegBench基准构建了大量随机构建的概率有限自动机(PFA)定义的人工语言。每个输入序列由10到20个字符串组成,这些字符串从同一个PFA中抽取并用分隔符token分隔。竞争架构在可变数量的输入序列上训练,然后评估它们预测使用保留PFA生成的测试序列的最后一个token的能力。
由于RegBench为每个数据点执行超参数搜索,使用图8中的记忆镶嵌架构,与transformers使用相同的搜索空间,确保transformers和记忆镶嵌对于相同的架构超参数具有相同的参数数量。扫描深度 N b ∈ 2 , 4 , 8 N_b \in {2,4,8} Nb∈2,4,8,头数 N h = N c = N p ∈ 2 , 4 , 8 N_h=N_c=N_p \in {2,4,8} Nh=Nc=Np∈2,4,8,嵌入维度 d ∈ 64 , 128 , 256 d \in {64,128,256} d∈64,128,256,权重衰减 ∈ 1 0 − 2 , 1 0 − 1 \in {10^{-2},10^{-1}} ∈10−2,10−1,以及训练epoch ∈ 1 , 2 , . . . 200 \in {1,2,...200} ∈1,2,...200。
图11将记忆镶嵌在RegBench上与Akyürek et al.先前报告的结果进行比较。左图显示测试字符串最后一个token的预测精度。右图将预测的最后一个token分布与PFA隐含的精确分布进行比较。记忆镶嵌在这个基准上占主导地位,在覆盖三个数量级的训练集大小上大大优于transformers、循环神经网络和状态空间模型。
注意力差异 因为记忆镶嵌缺乏位置编码并且不区分键和查询,我们研究了它们的注意力模式与transformers的差异。图12显示了单块深度的transformer使用绝对位置编码(左图)或单块深度的记忆镶嵌(右图)的每个头的注意力分数。这些分数在5000个BabiStories序列上取平均值,显示最后一个位置如何关注512个token长的上下文窗口中较早的位置。transformer的注意力模式是嘈杂的,在位置0处有一个强"注意力汇"。相比之下,记忆镶嵌的注意力模式大多是平坦的,除了最近token的分数较高。
图13显示了扩展到1536个token的上下文的注意力模式,使用在512个token长的序列上训练的模型。因为绝对位置编码方案无法扩展到更长的上下文,我们提供了与使用RoPE [Su et al., 2024]和AliBi [Press et al., 2022]的transformers的比较。RoPE注意力模式在超出训练上下文长度时无法很好地扩展。AliBi注意力模式显示了远token的贡献减弱。相比之下,记忆镶嵌的注意力模式基本保持平坦。
6 讨论
这项工作的出发点由两个非常古老的想法组成。第一个是用显式记忆增强深度网络。第二个是让学习过程决定记忆什么以及如何检索它。尽管这些想法已经在记忆网络中得到探索,但拥有大量独立记忆的重要性尚未得到充分认识。
这项工作重点关注使用核平滑实现的关联记忆网络,因此适合于基于梯度的学习算法。这种学习机器不仅类似于解码transformers(第2节),而且在使它们出名的语言建模任务上的表现也非常类似于解码transformers(第5节)。尽管需要大量工作才能在更大的规模上复制观察结果,但记忆镶嵌满足叙事约束的程度与transformers一样好(表1),并且通常表现出非常鼓舞人心的方式(图10至13)。
最重要的是,该研究对记忆镶嵌的理解远远超过我们对transformers的理解。首先,关联记忆单元的值提取函数精确描述了每个记忆试图记忆的内容。其次,预测解耦原则解释了为什么训练记忆镶嵌将整体预测任务分解为当独立考虑时比聚合考虑时更有效记忆的部分(第3节)。因此,记忆镶嵌不仅是一种类似transformer的架构,而且还是组合学习系统的一个模型9,它将知识分解为独立记忆的片段,然后根据需要重新组装它们,使用可以被视为记忆知识片段的组合策略(第4节)。
关注记忆让我们能够提出新的问题。记忆是否可以在不同的时间尺度上独立运作?能否设想比简单区分持久记忆和上下文记忆更丰富的记忆层次结构?中间记忆层是否可以像上下文记忆一样训练,即不使用梯度?持久知识能否被简化为紧凑的高阶偏置?
记忆镶嵌还提供了一系列工程机会。有限存储的上下文记忆可以通过驱逐最近最少使用的条目而不是通过限制上下文大小来实现。关联记忆可以使用广泛的技术来实现,从快速变换到局部敏感哈希机制,这可能会改变当代人工智能系统的计算要求。
参考论文出处: https://arxiv.org/pdf/2405.06394
译注:
- 本文没有讨论椭圆轨道或多体问题等微妙之处。我们的天文学家最好与那些努力最终产生托勒密模型的古代观星者相比。
- 这个效应与 leaky average 系数 λ ϕ \lambda_\phi λϕ 有关,如附录(图 15)所示。
- 当损失是有界的时,蒸汽压路机的比喻(图 2)更有意义。
- 与 GPT2-small 相比,我们节省了 768 × 512 768 \times 512 768×512 位置编码权重和 N b × 76 8 2 N_b \times 768^2 Nb×7682 查询投影权重,并为持久记忆键提取函数和混合层添加了 2 × N b × 76 8 2 2 \times N_b \times 768^2 2×Nb×7682 权重。因此,持久记忆单元插槽的总数接近但不完全等于 FFN 隐藏单元的数量。
- 这并不奇怪,因为记忆马赛克架构只需要一个块来正确处理 Bietti et al. [2024] 的归纳头问题,而 transformer 架构需要两个块来完成相同的任务。
- https://simple.wikipedia.org/wiki/Simple_English_Wikipedia
- 在低训练环境情况下(例如100),所有在RegBench OOD测试集上表现不佳的基线方法实际上可以在IID测试集上表现得很好。因此,基线方法学习了训练环境(良好的IID),但没有学习元学习能力(糟糕的OOD)。请查看附录表4了解详情。
- 不是"统计模型",而是用来描述和解释现象的"模型"。
附录
A BabiStories
TinyStories数据集由用简单语言编写并发生在狭窄世界中的故事组成。这些故事可用于训练相对较小的语言模型,这些模型仍然必须解决一些更广泛的语言建模挑战,例如遵守叙事必要性和保持逻辑一致性。这个数据集是一种用可接受的计算和快速周转来研究大问题的绝佳方式。
第 5 节的实验是使用使用类似方法生成的数据集进行的,但使用 Mixtral-8x7B 开放语言模型生成无责任数据。我们将这个数据集称为 BabiStories。所有的科学荣誉仍然归功于 Eldan 和 Li 的出色工作。
表 2 提供了这个新生成的 BabiStories 数据集的基本统计信息,基本上与 Eldan 和 Li 的原始 TinyStories 数据集相匹配。我们不得不通过扩展提示以指定名字并提供故事的开头词来增加生成故事的多样性,除了 Eldan 和 Li 使用的所需单词和故事特征(图 14)。我们还删除了少数包含 URL 的生成故事。
图14: BabiStories的生成。为了提高生成的多样性,每个故事都由一个提供所需单词和故事特征列表的提示生成(如Eldan和Li, 2023),并且还提供名字和开头词。
表2: BabiStories统计
B GPT2基线和超参数
表3展示了在BabiStories数据集上GPT2 transformer基线的超参数搜索过程,其中我们使用AdamW优化器,批量大小512,上下文大小512,以及具有最小学习率1e-4的余弦学习率调度器进行所有训练。
表3: 在 N b = 12 N_b=12 Nb=12的GPT2 transformer上进行超参数搜索。如果有的话,"dropout"应用于注意力分数、注意力头输出(在组合层之前)和FFN输出。
C 用于语言建模的记忆镶嵌
C.1 额外的定性评估
表5提供了与第5节表1类似的延续生成比较,但设置 N b = 1 N_b=1 Nb=1。
C.2 注意力图和leaky average系数 λ ϕ \lambda_\phi λϕ
图15显示了注意力图和leaky average系数 λ ϕ \lambda_\phi λϕ之间的关系。
(图15: 注意力图和leaky average系数 λ ϕ \lambda_\phi λϕ。随着 λ ϕ \lambda_\phi λϕ增加, k t k_t kt在等式7中有效地考虑了更长的历史,因此注意力图末端的峰值变得更宽。)
C.3 上下文语言学习
表4提供了在RegBench上仅使用100个训练环境训练的各种架构的IID测试性能。我们保持训练过程,包括超参数搜索空间,与图11中的相同。但从与训练集相同的100个概率有限自动机(训练环境)中采样验证集和测试集。该表与图11一起表明,基线方法学习了训练环境(良好的IID),但没有学习元学习能力(糟糕的OOD)。
表4:在RegBench上仅使用100个训练环境训练的各种架构的分布内(IID)性能。训练集、验证集和测试集(100个样本)都从相同的100个随机概率有限自动机(PFA)中采样。与图11中基线方法的糟糕OOD准确率(0.45)/TVD(0.75)相比,所有基线方法在IID测试集上表现良好(即使只有100个训练环境)。
(表5与表1类似,比较了 N b = 1 N_b=1 Nb=1的情况下两个模型生成的延续)
这篇论文提出了一种新的学习系统架构——记忆镶嵌,它由多个关联记忆网络协同工作来完成预测任务。记忆镶嵌具有与transformer类似的组合和上下文学习能力,但以更透明的方式实现。
我认为这项工作有几个有趣和有价值的地方:
- 记忆镶嵌利用了平滑关联记忆和self-attention之间的相似性,这为理解和改进transformer型模型提供了新的视角。
- 预测解耦原则解释了训练如何将整体任务分解为可以更高效学习的子问题,这对理解复杂神经网络的工作机制很有启发。
- 在语言建模任务上,这种相对透明的架构取得了与transformer相当的性能,表明了其实用潜力。
- 与transformer相比,记忆镶嵌的内部机制更容易理解。值提取函数精确描述了每个记忆单元记忆的内容。这使得记忆镶嵌成为组合学习系统的一个很好的研究模型。
- 作者还讨论了一些有趣的未来研究方向,如更丰富的记忆层次、动态路由等,这为后续工作指明了方向。
不过这项工作也有一些局限和值得进一步探索的地方:
- 目前的实验规模还较小,在大规模数据和任务上的表现还有待考察。记忆镶嵌能否像transformer一样具有良好的扩展性还需要更多实践检验。
- 一些关键的细节,如记忆单元的具体实现、路由机制等,论文中的讨论还比较简略,值得进一步细化。
- 尽管记忆镶嵌在可解释性上比transformer有优势,但其内部机制的直观理解还比较初步,需要更多理论和实证分析。
- 如何在记忆镶嵌的框架下实现few-shot learning、transfer learning等高级学习能力还有待探索。
我认为这项工作从关联记忆和组合学习的角度,为理解和改进transformer式模型提供了新的视角和思路,是一项有洞见和潜力的研究。未来如果能在更大规模、更多类型的任务上取得突破,记忆镶嵌有望成为深度学习领域一个重要的研究方向。同时其思路也可以启发其他形式的组合学习系统的设计。