Let's start with PosWiseFFN
import torch.nn as nn

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.GELU(),
            nn.Linear(d_ff, d_model, bias=False))
        # create the LayerNorm once here so its parameters are registered with the module
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):  # inputs: [batch_size, seq_len, d_model]
        residual = inputs
        output = self.fc(inputs)
        return self.layer_norm(output + residual)  # [batch_size, seq_len, d_model]
If the attention dimension is d_model, the PosWiseFFN structure is usually just two matrices with a GELU in between, with d_ff set to 4 times d_model: the first matrix has weight [d_model, 4*d_model] and the second has weight [4*d_model, d_model].
PosWiseFFN can also be read as a kind of q-k-v lookup. Treat the first matrix as the keys and the second as the values: the [batch_size, seq_len, d_model] input acts as the query and is first multiplied with the keys, producing dots of shape [batch_size, seq_len, 4*d_model]; after a GELU, the dots are multiplied with the second [4*d_model, d_model] matrix, which effectively uses the match scores to pull out a weighted combination of the value rows, back in d_model dimensions. The question is: can the d_ff of 4*d_model be made even larger? Figure 1 here is Figure 1 of Large Memory Layers with Product Keys; the |K| in that figure plays the role of 4*d_model in PosWiseFFN.
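To make the q-k-v reading concrete, here is a minimal shape-level sketch (the names keys/values are just illustrative labels for the two Linear weights, not anything from the paper):

import torch
import torch.nn.functional as F

batch_size, seq_len, d_model = 2, 8, 16
d_ff = 4 * d_model  # the |K| in the PKM paper's Figure 1 plays this role

x = torch.randn(batch_size, seq_len, d_model)   # the "query"
keys = torch.randn(d_model, d_ff)               # weight of the first Linear
values = torch.randn(d_ff, d_model)             # weight of the second Linear

dots = x @ keys                                 # [batch_size, seq_len, d_ff]
out = F.gelu(dots) @ values                     # [batch_size, seq_len, d_model]
print(out.shape)                                # torch.Size([2, 8, 16])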
In short, the PKM described below improves this q-k-v lookup by borrowing the idea of PQ.
PKM (Product Key Memory; the "Product" here is the Product of Product Quantization)
In Figure 1 of Large Memory Layers with Product Keys, q has shape [..., d_model] and the keys have shape [d_model, |K|]. Figure 2 shows how the paper handles a very large |K|: split the d_model-dimensional q into q1 and q2, each of dimension d_model/2; likewise, replace the [d_model, |K|] keys with two sub-key sets of dimension d_model/2 each, sub-key set 1 ($c_1, c_2, c_3$, the unprimed ones in the figure) and sub-key set 2 ($c'_1, c'_2, c'_3$, the primed ones), whose Cartesian product spans all |K| keys, so each sub-key set only needs on the order of $\sqrt{|K|}$ entries. Each half produces its own top-k, and the final k are then selected out of the $k^2$ combined candidates. This is exactly the idea behind Product Quantization.
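A toy sketch of this product top-k trick on random tensors (purely illustrative shapes and names, not the paper's code), mirroring what the full PKM implementation below does per head:

import torch

d_model, num_sub_keys, topk = 16, 10, 3        # 10 sub-keys per half -> 10 * 10 = 100 product keys

q = torch.randn(d_model)
sub_keys_1 = torch.randn(num_sub_keys, d_model // 2)   # c_1 ... c_n
sub_keys_2 = torch.randn(num_sub_keys, d_model // 2)   # c'_1 ... c'_n

q1, q2 = q.chunk(2)                                     # split the query in half
scores_1, idx_1 = (sub_keys_1 @ q1).topk(topk)          # top-k for the first half
scores_2, idx_2 = (sub_keys_2 @ q2).topk(topk)          # top-k for the second half

# scores of all k*k combinations; a product key's score is the sum of its two halves
all_scores = scores_1[:, None] + scores_2[None, :]          # [topk, topk]
all_indices = idx_1[:, None] * num_sub_keys + idx_2[None, :]

final_scores, flat_pos = all_scores.flatten().topk(topk)
final_indices = all_indices.flatten()[flat_pos]             # indices into the 100 product keys
print(final_indices, final_scores)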
Code walkthrough
The code below is from https://github.com/lucidrains/product-key-memory/tree/master; the einops usage there is quite nice. A few comments are added:
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange, Reduce
# init_, MaskedBatchNorm1D, gumbel_noise and coor_descent_topk are helpers
# defined elsewhere in the same repo

class PKM(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        num_keys = 128,
        topk = 32,
        dim_head = 128,
        input_dropout = 0.,
        query_dropout = 0.,
        value_dropout = 0.,
        attn_dropout = 0.,
        use_layernorm = True,
        pre_layernorm = False,
        differentiable_topk = False,
        concat_values_and_combine = False,
        norm_output = False,
        non_competitive_gates = False # Csordas et al. claims non-competitive gates work even better
    ):
        super().__init__()
        self.topk = topk
        self.heads = heads
        self.num_keys = num_keys

        dim_query = dim_head * heads * 2
        self.to_queries = nn.Linear(dim, dim_query, bias = False)

        # pre-layernorm pattern
        self.pre_layernorm = nn.LayerNorm(dim) if pre_layernorm else nn.Identity()

        # batchnorm would break causality
        self.use_layernorm = use_layernorm

        if use_layernorm:
            self.norm = nn.LayerNorm(dim_head)
        else:
            self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

        # keys
        self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
        init_(self.keys)

        # values
        self.concat_values_and_combine = concat_values_and_combine

        if concat_values_and_combine:
            values = nn.Embedding(num_keys ** 2, dim_head)

            self.values = nn.Sequential(
                values,
                Reduce('b (h k) d -> b h d', 'sum', h = heads),
                Rearrange('b n d -> b (n d)'),
                nn.Linear(dim_head * heads, dim, bias = False)
            )
        else:
            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
            self.values = values

        init_(values.weight)

        # dropouts
        self.input_dropout = nn.Dropout(input_dropout)
        self.query_dropout = nn.Dropout(query_dropout)
        self.value_dropout = nn.Dropout(value_dropout)
        self.attn_dropout = nn.Dropout(attn_dropout)

        # non competitive gates
        self.gate_activation = nn.Softmax(dim = -1) if not non_competitive_gates else nn.ReLU()

        # use a differentiable topk, based on coordinate descent
        self.differentiable_topk = differentiable_topk

        # https://arxiv.org/abs/2302.06461
        # claims to boost performance of softmax key / value networks by simply layernorming the output
        self.output_norm = nn.LayerNorm(dim) if norm_output else nn.Identity()
    def forward(
        self,
        x,
        input_mask = None,
        gumbel_noise_scale = 0.,
        **kwargs
    ):
        b, t, h = *x.shape[:2], self.heads

        x = self.pre_layernorm(x)
        x = self.input_dropout(x)

        queries = self.to_queries(x)

        # queries shape legend: b = batch_size, t = target_seq_len, p = partition (the two query halves), h = num_heads, d = head_dim
        queries = rearrange(queries, 'b t (p h d) -> (b p h) t d', p = 2, h = h)

        # norm and dropout queries
        norm_kwargs = dict(mask = input_mask) if not self.use_layernorm else dict()
        queries = self.norm(queries, **norm_kwargs)
        queries = self.query_dropout(queries)

        queries = rearrange(queries, '(b p h) t d -> p b t h d', p = 2, h = h)

        # similarity to keys
        # self.keys has shape [heads, num_keys, 2, dim_head]; n below is the num_keys axis
        # i.e. the keys are really two sub-key sets (p = 2), each a [num_keys, dim_head] table per head
        dots = einsum('p b t h d, h n p d -> b t h p n', queries, self.keys)

        # gumbel noise
        if gumbel_noise_scale > 0.:
            dots = dots + gumbel_noise(dots) * gumbel_noise_scale

        # topk scores
        if self.differentiable_topk:
            scores, indices, *_ = coor_descent_topk(dots, k = self.topk, fused = True)
        else:
            scores, indices = dots.topk(k = self.topk, dim = -1)

        # scores are factorized
        (scores_x, scores_y), (indices_x, indices_y) = map(lambda t: t.chunk(2, dim = 3), (scores, indices))

        all_topk = self.topk ** 2

        all_scores = rearrange((
            rearrange(scores_x, '... k -> ... k 1') +
            rearrange(scores_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        all_indices = rearrange((
            rearrange(indices_x, '... k -> ... k 1') * self.num_keys +
            rearrange(indices_y, '... k -> ... 1 k')
        ), 'b t h ... -> b t h (...)')

        final_topk, final_indices = all_scores.topk(self.topk, dim = -1)
        value_indices = all_indices.gather(-1, final_indices)

        # attention
        attn = self.gate_activation(final_topk)
        attn = self.attn_dropout(attn)

        value_indices, attn = map(lambda t: rearrange(t, 'b t h k -> (b t) (h k)'), (value_indices, attn))

        # aggregate
        if self.concat_values_and_combine:
            out = self.values(value_indices)
        else:
            out = self.values(value_indices, per_sample_weights = attn)

        out = self.value_dropout(out)

        # maybe layernorm the output
        out = self.output_norm(out)

        return rearrange(out, '(b t) d -> b t d', b = b)
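A quick usage sketch, assuming the class above is importable from the pip package (the exact import path below is an assumption; the shapes follow directly from the forward signature shown above):

import torch
from product_key_memory import PKM   # assumed import path

pkm = PKM(
    dim = 512,        # model dimension of the incoming hidden states
    heads = 4,
    num_keys = 128,   # 128 ** 2 = 16384 memory values in total
    topk = 32
)

x = torch.randn(1, 1024, 512)   # [batch_size, seq_len, dim]
out = pkm(x)                    # [1, 1024, 512], same shape as the input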
UltraMem
From ULTRA-SPARSE MEMORY NETWORK. When ByteDance announced it, the pitch was that it "effectively solves the heavy memory-access cost of MoE inference, with 2-6x faster inference than the MoE architecture and up to 83% lower inference cost". At first glance that sounds like DeepSeekMoE got another 2-6x speedup, but the MoE being compared against is actually the one from the paper below. UltraMem is essentially an improvement on the PKM idea, but ByteDance has not released source code, and whether Doubao actually uses it is anyone's guess. A few core ideas are excerpted here; a closer read will have to wait until the code is out.
To address drawback 1 and drawback 3 (as enumerated in the paper), PQ is replaced with the TDQKR below, a method based on SVD decomposition:
MoE
This MoE is not the MoE of MoE-architecture LLMs, but another improvement on PosWiseFFN, from Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. The routing figure in that paper conveys the rough idea at a glance; a minimal sketch follows below.
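As a rough illustration of the Switch Transformer idea (not the paper's actual code; all names here are made up), a minimal top-1-routing MoE layer over several small FFN experts might look like this:

import torch
import torch.nn as nn

class SwitchFFN(nn.Module):
    """Minimal sketch: each token is routed to exactly one expert FFN (top-1 routing)."""
    def __init__(self, d_model, d_ff, num_experts):
        super().__init__()
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):                        # x: [batch_size, seq_len, d_model]
        gates = self.router(x).softmax(dim=-1)   # [batch_size, seq_len, num_experts]
        top_gate, top_idx = gates.max(dim=-1)    # top-1 expert per token
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = top_idx == i                  # tokens routed to expert i
            if mask.any():
                # scale by the gate value so the router still receives gradient
                out[mask] = expert(x[mask]) * top_gate[mask].unsqueeze(-1)
        return out

# usage
layer = SwitchFFN(d_model=16, d_ff=64, num_experts=4)
y = layer(torch.randn(2, 8, 16))
print(y.shape)  # torch.Size([2, 8, 16])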
Appendix:
- https://mp.weixin.qq.com/s/BPGbzAQ5AKPj7yqrOCCuGQ?token=2117558689&lang=zh_CN
- https://team.doubao.com/zh/publication/ultra-sparse-memory-network?view_from=research
- https://www.cls.cn/detail/1940788