-
2017年Google在论文《Attention is All You Need》中提出了Transformer模型,并成功应用到NLP领域。该模型完全基于自注意力机制Attention mechanism实现,弥补了传统的RNN模型的不足。宏观层面,Transformer可以看成是一个黑箱操作的序列到序列(seq2seq)模型。例如,在机器翻译中,输入一种语言,经Transformer输出翻译后的另一种语言。拆开这个黑箱,可以看到模型本质就是一个Encoders-Decoders结构。Self - Attention是Transformer中最核心的思想。
-
每个Encoders中分别由6层Encoder组成。(所有Encoder结构完全相同,但是训练参数不同,每个参数是独立训练的,循环执行6次Encode,而不是只训练了一个Encoder然后复制5份)。Decoders同理。
-
Encoders端:经过词向量层(Input Embedding)和位置编码层(Positional Encoding),得到最终输入,流经自注意力层(Multi-Head Attention)、残差和层归一化(Add&Norm)、前馈神经网络层(Feed Forward)、残差和层归一化(Add&Norm),得到编码端的输出(后续会和解码端进行交互)。
-
-
Decoders端:经过词向量层(Output Embedding)和位置编码层(Positional Encoding),得到最终输入,流经掩码自注意力层(Masked Multi-Head Attention,把当前词之后的词全部mask掉)、残差和层归一化(Add&Norm)、交互注意力层(Multi-Head Attention,把编码端的输出和解码端的信息进行交互,Q矩阵来自解码端,K、V矩阵来自编码端的输出)、残差和层归一化(Add&Norm)、前馈神经网络层(Feed Forward)、残差和层归一化(Add&Norm),得到解码端的输出。查询、键、值,看到这中文的意思,还是迷迷糊糊的。我们来举个例子:小明想在b站搜索深度学习,他把深度学习四个字输入到搜索栏,按下搜索键。搜索引擎就会将他的查询query映射到数据库中相关的标签key,如吴恩达、神经网络等等,然后向小明展示最匹配的结果value。
-
自注意力机制小尝
-
在进行Self - Attention之前,我们首先定义3个1×4的input。 pytorch代码如下:
-
import torch x = [ [1, 0, 1, 0], # input 1 [0, 2, 0, 2], # input 2 [1, 1, 1, 1] # input 3 ] x = torch.tensor(x, dtype=torch.float32)
-
每个input和三个权重矩阵分别相乘会得到三个新的矩阵,分别是key,query,value。我们已经令input的shape为1×4,key、query、value的shape为1×3,因此可以推出与input相乘的权重矩阵的shape为4×3。 代码如下:
-
w_key = [ [0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0] ] w_query = [ [1, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 1] ] w_value = [ [0, 2, 0], [0, 3, 0], [1, 0, 3], [1, 1, 0] ] w_key = torch.tensor(w_key, dtype=torch.float32) w_query = torch.tensor(w_query, dtype=torch.float32) w_value = torch.tensor(w_value, dtype=torch.float32) print("Weights for key: \n", w_key) print("Weights for query: \n", w_query) print("Weights for value: \n", w_value)
-
现在我们计算key,query和value矩阵的值,计算的过程也很简单,运用矩阵乘法即可: key = input * w_key; query = input * w_query; value = input * w_value;
-
keys = x @ w_key querys = x @ w_query values = x @ w_value print("Keys: \n", keys) # tensor([[0., 1., 1.], # [4., 4., 0.], # [2., 3., 1.]]) print("Querys: \n", querys) # tensor([[1., 0., 2.], # [2., 2., 2.], # [2., 1., 3.]]) print("Values: \n", values) # tensor([[1., 2., 3.], # [2., 8., 0.], # [2., 6., 3.]])
-
为了获得input1的注意力分数(attention scores),我们将 input1 的query与 input1、2、3的 key 的转置分别作点积,得到3个attention scores。 同理,我们也可以得到input2和input3的attention scores。
-
attn_scores = querys @ keys.T print(attn_scores) # tensor([[ 2., 4., 4.], # attention scores from Query 1 # [ 4., 16., 12.], # attention scores from Query 2 # [ 4., 12., 10.]]) # attention scores from Query 3
-
上一步得到了attention scores矩阵后,我们对attention scores矩阵作softmax计算。softmax的作用为归一化,使得其中各项相加后为1。这样做的好处是凸显矩阵中最大的值并抑制远低于最大值的其他分量。
-
from torch.nn.functional import softmax attn_scores_softmax = softmax(attn_scores, dim=-1) print(attn_scores_softmax) # tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01], # [6.0337e-06, 9.8201e-01, 1.7986e-02], # [2.9539e-04, 8.8054e-01, 1.1917e-01]])
-
每个score乘以其对应的value得到3个alignment vectors。我们将它们称为weighted values(加权值)。
-
weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None] print(weighted_values)
-
每个input生成3个weighed values,我们将这3个weighted values相加,得到output。图中一共有3个input,所以最终生成3个output。动图轻松理解Self-Attention(自注意力机制) - 知乎 (zhihu.com)
-
outputs = weighted_values.sum(dim=0) print(outputs)
-
-
输入Inputs维度是[batch size,sequence length],经Word2Vec,转换为计算机可以识别的Input Embedding,论文中每个词对应一个512维度的向量,维度是[batch_size,sequence_length,embedding_dimmension]。batch size指的是句子数,sequence length指的是输入的句子中最长的句子的词数,embedding_dimmension是词向量长度。
-
首先对输入的文字进行Word Embedding处理,每个字(词)用一个连续型向量表示(可以定义的是4维向量),称为词向量。这样一个句子,也就是嵌入后的输入向量input Embedding就可以用一个矩阵表示(4*4维,序列长度为4,每个字用 4 维向量表示)。input Embedding加上位置信息得到编码器的输入 α \alpha α 矩阵。
- 为什么需要在input Embedding加上位置信息?与RNN相比,RNN是一个字一个字输入,自然可以保留每个字的顺序关系信息,而Transformer使用的是自注意力机制来提取信息,一个句子中的每个字/词是并行计算,虽然处理每个字的时候考虑到了所有字对其的影响,但是并没有考虑到各个字相互之间的位置信息,也就是上下文。所以需要引入位置信息。Transformer中使用Positional Encoding表示每个字/词的位置信息。定义如下:
-
pos:表示一句话中某个字的实际位置。表示第1个字,pos=0;表示第2个字,pos=1。
-
P E p o s , i PE_{pos,i} PEpos,i :表示pos位置处的字的Positional Encoding向量,该向量可以用来给句子中每个字提供位置信息。换句话说,就是我们通过注入每个字的位置信息,增强了模型输入。
-
d m o d e l d_{model} dmodel :表示词向量的维度。论文中的512。
-
i 表示词向量的第i+1维度。例如 p 12 p^{12} p12 表示第 2 个字的第 3 维度,i 为奇数,使用cos函数;i为偶数,使用sin函数。(从0开始计数)
- 例如 p 22 p^{22} p22 表示第3个字(pos=2)的第3维度(i=2),对应的值就是 P E 2 , 2 = s i n ( 1 100 0 2 ∗ 2 512 ) PE_{2,2}=sin(\frac{1}{1000^{\frac{2*2}{512}}}) PE2,2=sin(10005122∗21)
- 例如 p 13 p^{13} p13 表示第2个字(pos=1)的第4维度(i=3),对应的值就是 P E 1 , 3 = c o s ( 1 100 0 1 ∗ 3 512 ) PE_{1,3}=cos(\frac{1}{1000^{\frac{1*3}{512}}}) PE1,3=cos(10005121∗31)
-
-
Word Embedding在Pytorch中通常用nn.Embedding实现。
-
class Embeddings(nn.Module): """ 类的初始化 :param d_model: 词向量维度,512 :param vocab: 当前语言的词表大小 """ def __init__(self, d_model, vocab): super(Embeddings, self).__init__() # 调用nn.Embedding预定义层,获得实例化词嵌入对象self.lut self.lut = nn.Embedding(vocab, d_model) self.d_model = d_model #表示词向量维度 def forward(self, x): """ Embedding层的前向传播 参数x:输入给模型的单词文本通过此表映射后的one-hot向量 x传给self.lut,得到形状为(batch_size, sequence_length, d_model)的张量,与self.d_model相乘, 以保持不同维度间的方差一致性,及在训练过程中稳定梯度 """ return self.lut(x) * math.sqrt(self.d_model)
-
-
Positional Encoding
-
class PositionalEncoding(nn.Module):"""实现Positional Encoding功能""" def __init__(self, d_model, dropout=0.1, max_len=5000): """ 位置编码器的初始化函数 :param d_model: 词向量的维度,与输入序列的特征维度相同,512 :param dropout: 置零比率 :param max_len: 句子最大长度,5000 """ super(PositionalEncoding, self).__init__() # 初始化一个nn.Dropout层,设置给定的dropout比例 self.dropout = nn.Dropout(p=dropout) # 初始化一个位置编码矩阵 # (5000,512)矩阵,保持每个位置的位置编码,一共5000个位置,每个位置用一个512维度向量来表示其位置编码 pe = torch.zeros(max_len, d_model) # 偶数和奇数在公式上有一个共同部分,使用log函数把次方拿下来,方便计算 # position表示的是字词在句子中的索引,如max_len是128,那么索引就是从0,1,2,...,127 # 论文中d_model是512,2i符号中i从0取到255,那么2i对应取值就是0,2,4...510 # (5000) -> (5000,1) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # 计算用于控制正余弦的系数,确保不同频率成分在d_model维空间内均匀分布 div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # 根据位置和div_term计算正弦和余弦值,分别赋值给pe的偶数列和奇数列 pe[:, 0::2] = torch.sin(position * div_term) # 从0开始到最后面,补长为2,其实代表的就是偶数位置 pe[:, 1::2] = torch.cos(position * div_term) # 从1开始到最后面,补长为2,其实代表的就是奇数位置 # 上面代码获取之后得到的pe:[max_len * d_model] # 下面这个代码之后得到的pe形状是:[1 * max_len * d_model] # 多增加1维,是为了适应batch_size # (5000, 512) -> (1, 5000, 512) pe = pe.unsqueeze(0) # 将计算好的位置编码矩阵注册为模块缓冲区(buffer),这意味着它将成为模块的一部分并随模型保存与加载,但不会被视为模型参数参与反向传播 self.register_buffer('pe', pe) def forward(self, x): """ x: [seq_len, batch_size, d_model] 经过词向量的输入 """ x = x + self.pe[:, :x.size(1)].clone().detach() # 经过词向量的输入与位置编码相加 # Dropout层会按照设定的比例随机“丢弃”(置零)一部分位置编码与词向量相加后的元素, # 以此引入正则化效果,防止模型过拟合 return self.dropout(x)
-
-
自注意力机制(Self Attention Mechanism)
-
对于机器来说其实就是赋予多少权重(比如0-1之间的小数),越重要的地方或者越相关的地方赋予的权重越高。注意力机制的实现思想是先计算第1个字与句中每个字的注意力分数(包括第1个字),再用求得的注意力分数与对应字的信息相乘,并相加,得到的结果就是第1个字与句子中所有字的加权和,第2个字、第3个字…以此类推。
-
如上图所示,以包含位置信息的词向量 α i \alpha^i αi 作为Self Attention Mechanism的输入。 α i \alpha^i αi 即为一句话中第i+1个词的词向量。 α i \alpha^i αi 分别乘以 W Q , W K , W V W^Q,W^K,W^V WQ,WK,WV 三个矩阵,得到 q i , k i , v i q^i,k^i,v^i qi,ki,vi。其中,q是查询向量,k是词的“被查”向量,v是词的“内容”向量
-
下面计算每个字的注意力信息。以第1个字与句子中所有字的注意力信息为例,首先 q 0 q^0 q0 分别乘以 k 0 , k 1 , k 2 , k 3 k^0,k^1,k^2,k^3 k0,k1,k2,k3,得到4个常数注意力值 α 00 , α 01 , α 02 , α 03 \alpha_{00},\alpha_{01},\alpha_{02},\alpha_{03} α00,α01,α02,α03,再对其分别经过Softmax归一化,得到第1个字与所有字的注意力分数 α 00 ^ , α 01 ^ , α 02 ^ , α 03 ^ \hat{\alpha_{00}},\hat{\alpha_{01}},\hat{\alpha_{02}},\hat{\alpha_{03}} α00^,α01^,α02^,α03^,它们的和为1,最后再用注意力分数与对应的字信息 v 0 , v 1 , v 2 , v 3 v^0,v^1,v^2,v^3 v0,v1,v2,v3 相乘,即可得到第1个字与句中所有字的加权信息。加权和:
-
b 0 = α 00 ^ ∗ v 0 + α 01 ^ ∗ v 1 + α 02 ^ ∗ v 2 + α 03 ^ ∗ v 3 b^0=\hat{\alpha_{00}}*v^0+\hat{\alpha_{01}}*v^1+\hat{\alpha_{02}}*v^2+\hat{\alpha_{03}}*v^3 b0=α00^∗v0+α01^∗v1+α02^∗v2+α03^∗v3
-
第2、3、4个字与句子中所有字的加权和 b 1 , b 2 , b 3 b^1,b^2,b^3 b1,b2,b3 以此类推。
-
-
实际中计算机为了加速计算,通常采用矩阵计算思想。如下图所示,首先词向量矩阵 α i \alpha^i αi 分别乘以 W Q , W K , W V W^Q,W^K,W^V WQ,WK,WV 三个矩阵,得到 q i , k i , v i q^i,k^i,v^i qi,ki,vi。其中 W Q , W K , W V W^Q,W^K,W^V WQ,WK,WV 矩阵的维度是词向量长度,词向量长度。
-
再用q矩阵乘以k矩阵得到注意力值矩阵 α \alpha α,如下图所示。其中, α = q ∗ k T d \alpha=\frac{q*k^T}{\sqrt d} α=dq∗kT, k T k^T kT:k矩阵的转置。d:词向量长度,论文中是512。
-
然后,矩阵 α \alpha α 每一行,经过Softmax计算出注意力分数矩阵 α i j ^ \hat{\alpha_{ij}} αij^。矩阵 α i j ^ \hat{\alpha_{ij}} αij^ 每一行的分数值和为1。最后,用注意力分数矩阵 α i j ^ \hat{\alpha_{ij}} αij^ 乘以 v j v^j vj 矩阵得到输出矩阵 b i b^i bi。
-
class ScaledDotProductAttention(nn.Module): """ Scaled Dot-Product Attention """ def __init__(self, scale_factor, dropout=0.0): super().__init__() self.scale_factor = scale_factor #dropout用于防止过拟合,在前向传播的过程中,让某个神经元的激活值以一定的概率停止工作 self.dropout = nn.Dropout(dropout) def forward(self, q, k, v, mask=None): # batch_size: 批量大小 # len_q,len_k,len_v: 序列长度 在这里他们都相等 # n_head: 多头注意力,论文中默认为8 # d_k,d_v: k v 的dim(维度) 默认都是64 # 此时q的shape为(batch_size, n_head, len_q, d_k) (batch_size, 8, len_q, 64) # 此时k的shape为(batch_size, n_head, len_k, d_k) (batch_size, 8, len_k, 64) # 此时v的shape为(batch_size, n_head, len_k, d_v) (batch_size, 8, len_k, 64) # q先除以self.scale_factor,再乘以k的转置(交换最后两个维度(这样才可以进行矩阵相乘))。 # attn的shape为(batch_size, n_head, len_q, len_k) attn = torch.matmul(q / self.scale_factor, k.transpose(2, 3)) if mask is not None: """ 用-1e9代替0 -1e9是一个很大的负数 经过softmax之后接近0 # 其一:去除掉各种padding在训练过程中的影响 # 其二,将输入进行遮盖,避免decoder看到后面要预测的东西。(只用在decoder中) """ scores = scores.masked_fill(mask == 0, -1e9) # 先在attn的最后一个维度做softmax 再dropout 得到注意力分数 attn = self.dropout(torch.softmax(attn, dim=-1)) # 最后attn与v矩阵相乘 # output的shape为(batch_size, 8, len_q, 64) output = torch.matmul(attn, v) # 返回 output和注意力分数 return output, attn
-
多头注意力机制(Multi-Head Attention )
-
多头注意力机制即就是把上述的 q i , k i , v i q^i,k^i,v^i qi,ki,vi 三个矩阵从特征维度(词向量长度)上拆分为形状相同的小矩阵,如下图所示,拆分为2个形状相同的小矩阵,即为二头注意力。本例中,句子长度为4,词向量维度是4,小矩阵维度即为[4,4/2=2]。接下来以上述方式计算2个b矩阵,再将每个Head Attention计算出来的b矩阵拼接,即为最终的注意力矩阵。注:论文中句子长度为5,词向量维度是512,将 q i , k i , v i q^i,k^i,v^i qi,ki,vi 三个矩阵拆分成了8个形状相同的小矩阵,也就是8头注意力,小矩阵维度为[5,512/8=64]。
-
其中,输入 α \alpha α 与最后输出b的形状相同。
-
class MultiHeadAttention(nn.Module): """ Multi-Head Attention module """ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): # 论文中这里的n_head, d_model, d_k, d_v分别默认为8, 512, 64, 64 # q k v先经过不同的线性层,再用ScaledDotProductAttention,最后再经过一个线性层 super().__init__() self.n_head = n_head self.d_k = d_k self.d_v = d_v self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) self.fc = nn.Linear(n_head * d_v, d_model, bias=False) self.attention = ScaledDotProductAttention(scale_factor=d_k ** 0.5) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)# 默认对最后一个维度初始化 def forward(self, q, k, v, mask=None): # q, k, v初次输入为含位置信息的嵌入矩阵X,由于要堆叠N次,后面的输入则是上个多头的输出 # q, k, v:batch_size * seq_num * d_model d_k, d_v, n_head = self.d_k, self.d_v, self.n_head # len_q, len_k, len_v 为输入的序列长度 # batch_size为batch_size batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) # 用作残差连接 residual = q # Pass through the pre-attention projection: b x lq x (n*dv) # Separate different heads: b x lq x n x dv # q k v 分别经过一个线性层再改变维度 # 由(batch_size, len_q, n_head*d_k) => (batch_size, len_q, n_head, d_k) # (batch_size, len_q, 8*64) => (batch_size, len_q, 8, 64) q = self.layer_norm(q) k = self.layer_norm(k) v = self.layer_norm(v) # 与q,k,v相关矩阵相乘,得到相应的q,k,v向量,d_model=n_head * d_k q = self.w_qs(q).view(batch_size, len_q, n_head, d_k) k = self.w_ks(k).view(batch_size, len_k, n_head, d_k) v = self.w_vs(v).view(batch_size, len_v, n_head, d_v) # Transpose for attention dot product: b x n x lq x dv # 交换维度做attention # (batch_size, len_q, n_head, d_k) => (batch_size, n_head, len_q, d_k) # (batch_size, len_q, 8, 64) => (batch_size, 8, len_q, 64) q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) if mask is not None: # 为head增加一个维度 mask = mask.unsqueeze(1) # For head axis broadcasting. # 输出的q为Softmax(QK/d + (1-S)σ)V, attn 为QK/D q, attn = self.attention(q, k, v, mask=mask) # Transpose to move the head dimension back: b x lq x n x dv # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) # (batch_size, 8, len_k, 64) => (batch_size, len_k, 8, 64) => (batch_size, len_k, 512) q = q.transpose(1, 2).contiguous().view(batch_size, len_q, -1) # 经过fc和dropout q = self.dropout(self.fc(q)) # 残差连接 论文中的Add & Norm中的Add q += residual # 论文中的Add & Norm中的Norm q = self.layer_norm(q) # q的shape为(batch_size, len_q, 512) # attn的shape为(batch_size, n_head, len_q, len_k) return q, attn
-
-
Add采用残差神经网络思想,也就是Multi-Head Attention的输入 α \alpha α 矩阵直接与输出 b b b 相加,这样可以让网络训练的更深,得到 b ˉ \bar b bˉ 矩阵,再经过Layer normalization归一化处理,加快训练速度,使得 b ˉ \bar b bˉ 的每一行也就是每个句子归一化为标准正态分布,输出为 b ^ \hat b b^。公式如下:
-
均值: μ i = 1 s ∑ j = 1 s b i j μ_i=\frac1s\sum_{j=1}^sb_{ij} μi=s1∑j=1sbij,其中,s是 b i ˉ \bar{b_i} biˉ 的长度;
-
方差: δ i = 1 s ∑ j = 0 s ( b i j − μ i ) 2 \delta_i=\frac1s\sum_{j=0}^s(b_{ij}-μ_i)^2 δi=s1∑j=0s(bij−μi)2;
-
归一化: l a y e r n o r m ( x ) = b i j − u i δ i + η ∗ γ + β layernorm(x)=\frac{b_{ij}-\,u_i}{\delta_i+\eta}*\gamma+\beta layernorm(x)=δi+ηbij−ui∗γ+β。
-
class LayerNorm(nn.Module): def __init__(self, d_model, eps=1e-12): super().__init__() # 初始化尺度参数gamma self.gamma = nn.Parameter(torch.ones(d_model)) # 初始化偏差参数beta self.beta = nn.Parameter(torch.zeros(d_model)) # 设置一个小常数,防止除0 self.eps = eps def forward(self, x): # 计算均值 mean = x.mean(-1, keepdim=True) # 计算方差,unbiased=False时,方差的计算使用n而不是n-1做分母 var = x.var(-1, unbiased=False, keepdim=True) # 归一化计算 out = (x - mean) / torch.sqrt(var + self.eps) out = self.gamma * out + self.beta return out
-
-
将Add & Layer normalization输出 b ˉ \bar b bˉ,经过两个全连接层(第一层的激活函数为 Relu,第二层不使用激活函数),再经过Add & Layer normalization得到最后输出矩阵O。
-
class PoswiseFeedForwardNet(nn.Module): def __init__(self): super(PoswiseFeedForwardNet, self).__init__() self.fc = nn.Sequential( nn.Linear(d_model, d_ff, bias=False), nn.ReLU(), nn.Linear(d_ff, d_model, bias=False)) def forward(self, inputs): # inputs: [batch_size, seq_len, d_model] residual = inputs output = self.fc(inputs) return nn.LayerNorm(d_model).cuda()(output + residual)# [batch_size, seq_len, d_model]
-
-
Mask句子中没有实际意义的占位符,例如’我 是 学 生 P’ ,P对应句子没有实际意义,所以需要被Mask,Encoder_input 和Decoder_input占位符都需要被Mask。
-
# seq_q: [batch_size, seq_len] ,seq_k: [batch_size, seq_len] def get_attn_pad_mask(seq_q, seq_k): batch_size, len_q = seq_q.size() batch_size, len_k = seq_k.size() # eq(zero) is PAD token pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q), one is masking # 扩展成多维度 return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
-
-
EncoderLayer代码实现
-
class EncoderLayer(nn.Module): def __init__(self): super(EncoderLayer, self).__init__() self.enc_self_attn = MultiHeadAttention()# 多头注意力机制 self.pos_ffn = PoswiseFeedForwardNet()# 前馈神经网络 def forward(self, enc_inputs, enc_self_attn_mask): # enc_inputs: [batch_size, src_len, d_model] # 输入3个enc_inputs分别与W_q、W_k、W_v相乘得到Q、K、V # enc_self_attn_mask: [batch_size, src_len, src_len] enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_outputs: [batch_size, src_len, d_model], # attn: [batch_size, n_heads, src_len, src_len] enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model] return enc_outputs, attn
-
-
Encoder代码实现
-
class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.src_emb = nn.Embedding(src_vocab_size, d_model) # 把字转换字向量 self.pos_emb = PositionalEncoding(d_model) # 加入位置信息 self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)]) def forward(self, enc_inputs): # enc_inputs: [batch_size, src_len] # 1. 中文字索引进行Embedding,转换成512维度的字向量 enc_outputs = self.src_emb(enc_inputs) # enc_outputs: [batch_size, src_len, d_model] # 2. 在字向量上面加上位置信息 enc_outputs = self.pos_emb(enc_outputs) # enc_outputs: [batch_size, src_len, d_model] # 3. Mask掉句子中的占位符号 enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # enc_self_attn_mask: [batch_size, src_len, src_len] enc_self_attns = [] # 4. 通过6层的encoder(上一层的输出作为下一层的输入) for layer in self.layers: enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask) # enc_outputs : [batch_size, src_len, d_model], # enc_self_attn : [batch_size, n_heads, src_len, src_len] enc_self_attns.append(enc_self_attn) return enc_outputs, enc_self_attns
-
-
Decoder的输入是最后一个Encoder block的输出。如下图所示,以中文翻译“我是学生”为例,首先将“我是学生”整个句子输入到Encoder中,得到最后一个Encoder block的输出后,将在Decoder中输入"S I am a student",s表示开始。注意这里,“S I am a student"不会一并输入,而是在T0时刻先输入"S”,预测出第一个词"I";再在T1时刻,输入"S"和"I"预测下一个单词"am";同理在T2时刻,输入"S"、“I"和"am”,预测出第三个单词"a",依次把整个句子输入到Decoder,预测出"I am a student E"。
-
这里采用Mask上三角矩阵掩盖了Decoder的输入,T0、T1、T2、T3、T4即为每个时刻的输入。
-
def get_attn_subsequence_mask(seq): # seq: [batch_size, tgt_len] attn_shape = [seq.size(0), seq.size(1), seq.size(1)] subsequence_mask = np.triu(np.ones(attn_shape), k=1) # 生成上三角矩阵,[batch_size, tgt_len, tgt_len] subsequence_mask = torch.from_numpy(subsequence_mask).byte() # [batch_size, tgt_len, tgt_len] return subsequence_mask
-
Masked Multi-Head Attention与Multi-Head Attention类似,只是采用了Mask上三角矩阵,掩盖Decoder的输入。如上所述。Decoder的Multi-Head Attention同样和Encoder的Multi-Head Attention结构一样,只是Decoder的Multi-Head Attention中,K、V矩阵来自Encoder的输出,而Q矩阵来自Masked Multi-Head Attention 的输出。Decoder输出矩阵形状是[句子长度,词向量维度],经过nn.Linear全连接层,再通过softmax函数得到每个词的概率,然后选择概率最大的词作为预测结果。Decoder两次调用MultiHeadAttention时,第一次调用传入的 Q,K,V 的值是相同的,都等于dec_inputs,第二次调用 Q 矩阵是来自Decoder的输入。K,V 两个矩阵是来自Encoder的输出,等于enc_outputs。
-
class DecoderLayer(nn.Module): def __init__(self): super(DecoderLayer, self).__init__() self.dec_self_attn = MultiHeadAttention() self.dec_enc_attn = MultiHeadAttention() self.pos_ffn = PoswiseFeedForwardNet() def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): # dec_inputs: [batch_size, tgt_len, d_model] # enc_outputs: [batch_size, src_len, d_model] # dec_self_attn_mask: [batch_size, tgt_len, tgt_len] # dec_enc_attn_mask: [batch_size, tgt_len, src_len] dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask) # dec_outputs: [batch_size, tgt_len, d_model] # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len] dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) # dec_outputs: [batch_size, tgt_len, d_model] # dec_enc_attn: [batch_size, h_heads, tgt_len, src_len] dec_outputs = self.pos_ffn(dec_outputs) # dec_outputs: [batch_size, tgt_len, d_model] return dec_outputs, dec_self_attn, dec_enc_attn
-
-
Decoder代码实现
-
class Decoder(nn.Module): def __init__(self): super(Decoder, self).__init__() self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model) self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1, d_model),freeze=True) self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) def forward(self, dec_inputs, enc_inputs, enc_outputs): # dec_inputs : [batch_size x target_len] # 1. 英文字索引进行Embedding,转换成512维度的字向量,并在字向量上加上位置信息 dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[5,1,2,3,4]])) # 2. Mask掉句子中的占位符号 dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs) dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs) dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0) dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) dec_self_attns, dec_enc_attns = [], [] # 3. 通过6层的decoder(上一层的输出作为下一层的输入) for layer in self.layers: dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask) dec_self_attns.append(dec_self_attn) dec_enc_attns.append(dec_enc_attn) return dec_outputs, dec_self_attns, dec_enc_attns
-
-
Trasformer的整体结构,输入数据先通过Encoder,再通过Decoder,最后把输出进行多分类,分类数为英文字典长度,也就是判断每一个字的概率。
-
class Transformer(nn.Module): def __init__(self): super(Transformer, self).__init__() # 编码器 self.encoder = Encoder() # 解码器 self.decoder = Decoder() # 解码器最后的分类器,分类器的输入d_model是解码层每个token的输出维度大小,需要将其转为词表大小,再计算softmax;计算词出现的概率 self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False) def forward(self, enc_inputs, dec_inputs): # Transformer的两个输入,一个是编码器的输入(源序列),一个是解码器的输入(目标序列) # 其中,enc_inputs的大小应该是 [batch_size, src_len] ; dec_inputs的大小应该是 [batch_size, dec_inputs] """ 源数据输入到encoder之后得到 enc_outputs, enc_self_attns; enc_outputs是需要传给decoder的矩阵,表示源数据的表示特征 enc_self_attns表示单词之间的相关性矩阵 """ enc_outputs, enc_self_attns = self.encoder(enc_inputs) """ decoder的输入数据包括三部分: 1. encoder得到的表示特征enc_outputs、 2. 解码器的输入dec_inputs(目标序列)、 3. 以及enc_inputs """ dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs) """ 将decoder的输出映射到词表大小,最后进行softmax输出即可 """ dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size] return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns
-
-
当前,Transformer 已经成为了大语言模型的默认网络结构,为了降低大语言模型的训练成本,一些工作尝试对 Transformer 的计算成本进行优化,比如降低注意力运算的时间成本或者显存占用等。 Flash Attention,一种优化的注意力算法手撕Flash Attention!原理解析及代码实现 (qq.com)。Flash Attention 论文链接如下:[2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arxiv.org)
-
Transformer 包括编码器和解码器两部分,由于当前主流的大语言模型几乎都基于只含解码器而不含编码器的仅解码器 (decoder-only) 模型,该解码器通过多个解码器层堆叠而成,每层包含自注意力层、前馈神经网络、层归一化、残差连接等组件。其中,自注意力层接收一个特征序列作为输入,并将该序列输入作为查询 (Query, 下文简称 Q)、键 (Key, 下文简称 K) 和值 (Value, 下文简称 V),使用缩放点积 (Scaled-dot Production) 来计算 Q 和 K 之间的注意力权重矩阵,然后再通过注意力权重和 V 来计算自注意力层的输出。
-
Transformer 的主要组成部分为 attention,因此优化 Transformer 重点在于优化 attention 的计算。那么,attention 为什么需要优化呢?或者说,注意力机制为什么慢?[此处的“快慢”是相对而言的。严格意义上来说,相比于传统的 RNN,Transformer 中的 attention 可以并行地处理序列所有位置的信息(RNN 只能串行处理),因此计算效率并不低,但是仍然有可以进一步改进的空间]
-
对于科学计算程序而言,按照算数运算和内存读取各自所花的时间比例,科学计算通常分为计算密集型 (compute-bound) 和内存密集型 (memory-bound) 两类。其中,计算密集型运算的时间瓶颈主要在于算数计算,比如大型矩阵的相乘等,而内存密集型运算的时间瓶颈主要在于内存的读写时间,比如批归一化、层归一化等等。我们可以从计算和内存两方面来分析“attention为什么慢”这个问题,分别对应于时间复杂度和空间复杂度两个方面。
- 时间复杂度方面,attention 需要对矩阵 Q 和矩阵 K 的转置做乘法来得到注意力权重矩阵。不考虑 batch 维度,假设矩阵 Q 和 K 的尺寸都为 (n , dim),一个 (n , dim) 和 (dim, n) 的矩阵相乘的时间复杂度是序列长度 n 的平方级,即 attention 的时间复杂度为 O ( n 2 ) O(n^2) O(n2) 。当序列较长 (即 n 较大) 时, attention 的计算非常耗时。
- 空间复杂度方面, attention 的计算过程需要存储 S ( S = Q K T ) S(S=QK^T) S(S=QKT) 和 P ( P = s o f t m a x ( S ) ) P(P=softmax(S)) P(P=softmax(S)) 这两个尺寸均为 (n,n) 的矩阵,因此 attention **运算的空间复杂度也为 O ( n 2 ) O(n^2) O(n2) ** 。
-
为了对 attention 的内存读取时间有更清晰的感知,这里简单介绍 GPU 的内存层级。
-
GPU 的内存可以分为 HBM 和 SRAM 两部分。例如,A100 GPU 具有 40-80 GB 的高带宽内存 (上图中的 HBM,即我们平时说的“显存”),带宽为 1.5-2.0 TB/s,并且每个流式多处理器都有 192 KB 的片上 SRAM,带宽约为 19 TB/s。片上 SRAM 比 HBM 快一个数量级,但容量要小很多个数量级。在 GPU 运算之前,数据和模型先从 CPU 的内存(上图中的 DRAM)移动到 GPU 的 HBM,然后再从 HBM 移动到 GPU 的 SRAM,CUDA kernel 在 SRAM 中对这些数据进行运算,运算完毕后将运算结果再从 SRAM 移动到 HBM。
-
通过前面的空间复杂度分析,attention 运算需要占据的显存空间随着序列长度 n 的增长呈平方级增长。由于运算需要在 GPU 的 SRAM上 完成,这一过程需要不停地在 HBM 和 SRAM 之间交换数据,因此会导致大量的时间都消耗在 SRAM 和 HBM 之间的数据的换入换出上。
-
如何提高 attention 的计算效率
- 降低 attention 的计算复杂度:计算复杂度方面,一些工作尝试提出近似的 attention 算法,来降低 attention 的理论上的计算复杂度。主要可以分为稀疏 (sparse) 估计、低秩 (low-rank) 估计等。
- 其中,稀疏估计的基本思想是通过一个稀疏的矩阵来估计完整的、稠密 (dense)的注意力矩阵。比如,Reformer 对 Q 和 K 进行局部敏感哈希 (Local Sensitive Hashing, LSH),只对同一个桶 (bucket) 中的 Q 和 V 计算 attention,将 attention 的时间复杂度从 O ( n 2 ) O(n^2) O(n2) 降低到 O ( n l o g ( n ) ) O(nlog(n)) O(nlog(n)),而 Routing Transformer Q 和 K 进行聚类,只对同一个簇 (cluster) 中的 Q 和 K 计算attention,从而将attention的时间复杂度从 O ( n 2 ) O(n^2) O(n2) 降低到 O ( n 1.5 ) O(n^{1.5}) O(n1.5)。
- 低秩估计的基本思想通过一个低秩 (low-rank) 矩阵来估计注意力矩阵。比如, linear transformer引入核函数 ϕ ( x ) \phi(x) ϕ(x),将 s c o r e = s o f t m a x ( Q K T ) V score=softmax(QK^T)V score=softmax(QKT)V 形式化成$ score=\phi{Q}(\phi{K}^TV) $,来解耦开 softmax 运算中的 Q 和 K 。这样操作之后,可以先计算 ϕ K T V \phi{K}^TV ϕKTV ,然后再和 ϕ Q \phi Q ϕQ 相乘,由于 ϕ K T \phi K^T ϕKT 的尺寸为 (d,n) , ϕ Q \phi{Q} ϕQ 和 V 的尺寸为 ,因此 ϕ Q ( ϕ K T V ) \phi{Q}(\phi{K}^TV) ϕQ(ϕKTV) 的时间复杂度为 O ( n ) O(n) O(n) (简要推导: n ∗ d ∗ ( ( d ∗ n ) ∗ n ∗ d ) − > ( n ∗ d ) ∗ ( d ∗ d ) − > ( n ∗ d ) n*d*((d*n)*{n*d})->(n*d)*(d*d)->(n*d) n∗d∗((d∗n)∗n∗d)−>(n∗d)∗(d∗d)−>(n∗d) ,时间复杂度为 O ( n ) O(n) O(n)。
- 降低attention的空间复杂度:空间复杂度方面,这方面工作的基本思路是降低 attention 对于显存的需求,减少 HBM 和 SRAM 之间的换入换出,进而减少 attention 运算的时间消耗。
- 值得一提的是,“减少 attention 对于显存的需求”和“减少 HBM 和 SRAM 之间的换入换出”这两者之间并不等价,前者重点在于减少显存消耗,比如 memory-efficient attention,而后者重在降低数据交换的时间成本,比如 DATA MOVEMENT IS ALL YOU NEED: A CASE STUDY ON OPTIMIZING TRANSFORMERS。
- 为降低空间复杂度,一种具有代表性的方法是 kernel fusion。kernel fusion 的思想很简单,即将需要通过多个 CUDA kernel 来分步完成的操作融合到一个或者少数几个 CUDA kernel,从而减少数据在HBM和SRAM之间换入换出的次数,进而节省运算时间。
- 比如,我们在 SRAM 上计算 S = Q K T S=QK^T S=QKT ,将矩阵 S 写入到 HBM 中,然后再将矩阵 S 从 HBM 读入到 SRAM 中,计算 P = s o f t m a x ( S ) P=softmax(S) P=softmax(S)。上述两步可以合并在一个 kernel 中完成,即在 SRAM 中计算完 S 之后紧接着就通过 S 计算 P ,这样就可以避免在 HBM 和 SRAM 交换 S。Flash Attention 的做法其实也是 kernel fusion,只是对应的 kernel 专门针对数据的换入换出进行了优化 (IO-aware),尽可能最小化 HBM 和 SRAM 之间的数据交换次数。
- 降低 attention 的计算复杂度:计算复杂度方面,一些工作尝试提出近似的 attention 算法,来降低 attention 的理论上的计算复杂度。主要可以分为稀疏 (sparse) 估计、低秩 (low-rank) 估计等。
-
虽然降低 attention 的计算复杂度在理论上非常具有吸引力,但是在实际应用中仍然存在一些短板,比如以下两点:
- **性能比不上原始 attention。**不论是稀疏估计、低秩估计还是其他,这些方法都采用了某种近似算法来估算注意力权重矩阵,难免会丢失信息。目前主流的还是原始的attention;
- 无法减少内存读取的时间消耗。这些方法只能降低 attention 的计算复杂度,但是无法对 attention 运算过程中的空间复杂度等进行控制,无法减少内存读写带来的时间损耗。
-
和 Transformer 的原始 attention 相比,Flash Attention 有以下三特点:
- 运算速度更快 (Fast);
- 更节省显存 (Memory-Efficient);
- 计算结果相同 (Exact)。
-
得益于 Flash Attention 的这几点特性,自 PyTorch 2.0 开始,Flash Attention 已经被集成到 PyTorch 官方库中,使用者可以直接通过 torch.nn.functional.scaled_dot_product_attention 进行调用。Flash Attention 的动机是尽可能避免大尺寸的注意力权重矩阵在 HBM 和 SRAM 之间的换入换出。具体方法包含两个部分:tiling 和 recomputation。
- tiling 的基本思路:不直接对整个输入序列计算注意力,而是将其分为多个较小的块,逐个对这些块进行计算,增量式地进行 softmax 的规约。规约过程中只需要更新某些中间变量,不需要计算整个注意力权重矩阵。
- recomputation 的基本思路:基于 tiling 技巧,在反向传播过程中不保留整个注意力权重矩阵,而是只保留前向过程中 tiling 的某些中间变量,然后在反向传播过程中重新计算注意力权重矩阵。recomputation 可以看作是一种基于 tiling 的特殊的 gradient checkpointing。
-
基于Tiling技巧的Softmax:Tiling 技巧的核心思想是,尽可能避免对整个序列进行操作,而是通过维护一些中间变量来递推式地完成某些操作,从而减少内存的消耗。
-
为了展示 softmax 运算的详细过程,以下代码没有使用 PyTorch、Numpy 等科学计算库,或者Python原生的 max、min 等归约函数,而仅仅使用 Python 原生的数值运算符对浮点数的列表进行操作。
-
class SoftMax(object): def forward(self, x: List[float]): # loop 1: get the maximum value max_x = -np.inf for t in x: max_x = t if t > max_x else max_x # loop 2: get the accumulative sum of exp(x_i - x_max) accum_exp = 0. for t in x: accum_exp += np.exp(t - max_x) # loop 3: get the softmax output by dividing the exponential of `x-max(x)` with `accum_exp` output = [0. for _ in range(len(x))] for i, t in enumerate(x): output[i] = np.exp(t - max_x) / accum_exp return output def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs)
-
从上面的代码可以看出,softmax 函数需要三个循环,第一个循环计算数组的最大值,第二个循环计算 softmax 的分母,第三个循环计算 softmax 输出。使用 tiling 技巧的 softmax 的算法如下。
-
class SoftMaxWithTiling(object): def forward(self, x: List[float]): # loop 1: get the maximum value of x and the accumulated exponential values max_x = -np.inf accum_exp = 0. for t in x: max_x_new = t if t > max_x else max_x accum_exp = np.exp(max_x - max_x_new) * accum_exp + np.exp(t - max_x_new) max_x = max_x_new # loop 2: get the softmax output by dividing the exponential of `x-max(x)` with `accum_exp` out = [0. for _ in range(len(x))] for i, t in enumerate(x): out[i] = np.exp(t - max_x) / accum_exp return out
-
单元测试的代码如下
-
class SoftMaxTest(unittest.TestCase): def test_softmax(self): n_test = 10 for _ in range(n_test): n_elem = np.random.randint(1, 11) x = np.random.randn(n_elem).tolist() expected = torch.nn.functional.softmax(torch.tensor(x), dim=-1).tolist() out = SoftMax()(x) self.assertTrue(np.allclose(expected, out, atol=1e-4)) out_with_tiling = SoftMaxWithTiling()(x) self.assertTrue(np.allclose(expected, out_with_tiling, atol=1e-4)) if __name__ == "__main__": unittest.main()
-
该算法和原始的 softmax 最大的区别在于,我们在第一个循环中同时对最大值 m 以及 softmax 的分母 d 进行更新,从而减少了一个循环。在这个循环中,最大值 m 的更新和原始 softmax 相同,但是 softmax 分母的更新却稍有区别。在原始 softmax 中,我们已经通过第一个循环拿到了整个数组最大值,因此在第二个循环中可以直接计算 d,但是在此处,当进行到 for 循环的第j(1≤j<V)步时,我们手头只有子数组 x 1 : j x_{1:j} x1:j 的最大值,此时计算得到的 d 并不等于 d V d_V dV 。为了一直维护正确的 ,我们需要同步地对 d j d_j dj 进行更新。通过 tiling 的方式,softmax 的循环数从三个减到了两个,从而可以降低内存消耗。
-
-
Flash Attention 同样基于上述的tiling技巧实现,但是和上述的 sofmax 有两点不同:
- attention 的计算过程需要对 Q 和 K 进行内积,并且需要维护 attention 的输出矩阵 O;
- 在上述 tiling 形式的 softmax 中,我们的每一步只更新一个元素,但是 Flash Attention 将输入分为多个块,每个块包含多个元素。
-
由于我们无法直接从 Python 层面在 GPU 的 SRAM 和 HBM 之间进行数据交换,因此我们使用
load
和write
方法来分别模拟 HBM -> SRAM 和 SRAM -> HBM 的数据传输过程:-
def load(self, arr, st, ed, step): # Simulate the process that moves data from HBM to SRAM return arr[:, st * step: ed * step] def write(self, arr, val, st, ed, step): # Simulate the process that moves data from SRAM to HBM arr[:, st * step: ed * step] = val
-
-
结合代码来逐步理解该算法:输入: 矩阵 Q , K , V Q,K,V Q,K,V ,GPU 的 SRAM 大小 M。输出: attention 的输出矩阵 O
-
根据 GPU 的 SRAM 大小 M 和 Q 的特征维度 d,设置块大小 B c B_c Bc 和 B r B_r Br。 B c B_c Bc 和 B r B_r Br 的具体数值和使用的 GPU 有关,一个准则是让尽可能多的 GPU 流式处理器 (Streaming Multirocessor, SM) 处于工作状态。由于我们缺少 M 的具体数值,这里我们直接人为预先设定为常数值;
-
初始化输出矩阵 O 、中间量 m 和 l。m 和 l 的尺寸都为 (N,),其中 N 为输入序列的长度。 m和 l 的含义和上一节中基于 tiling 技巧计算 softmax 时的 m 和 l 相同。 i ∈ [ 1 : N ] , m i i\in [1:N],m_i i∈[1:N],mi 表示对 Q : i , K : i , V : i Q_{:i},K_{:i},V_{:i} Q:i,K:i,V:i 进行 attention 运算时 softmax 中分子中的 m a x ( x ) max(x) max(x) 项, l i l_i li 表示 softmax 的分母;
-
out = np.zeros((batch_size, q_len, hidden_size)) l = np.zeros((batch_size, q_len)) m = np.zeros((batch_size, q_len)) m.fill(-np.inf)
-
分别对 Q , K , V Q,K,V Q,K,V 进行分块;
-
Tr = q_len // self.row_block_size# Tr: number of row blocks Tc = k_len // self.col_block_size# Tc: number of column blocks
-
对 O , l , m O,l,m O,l,m 进行分块。此处代码中不需要额外操作; 对 K 和 V 的块进行遍历;
-
for j in range(Tc):
-
将 K 和 V 对应的块 K j K_j Kj 和 V j V_j Vj 从 HBM 加载到 SRAM 中
-
kj = self.load(k, j, j + 1, self.col_block_size) vj = self.load(v, j, j + 1, self.col_block_size)
-
对 Q Q Q 的块进行遍历;
-
for i in range(Tr):
-
将 Q , O , l , m Q,O,l,m Q,O,l,m 对应的块 Q i , O i , l i , m i Q_i,O_i,l_i,m_i Qi,Oi,li,mi 从 HBM 加载到 SRAM 中;
-
qi = self.load(q, i, i + 1, self.row_block_size) oi = self.load(out, i, i + 1, self.row_block_size) mi = self.load(m, i, i + 1, self.row_block_size) li = self.load(l, i, i + 1, self.row_block_size)
-
计算 Q 的块 Q i Q_i Qi 和 K 的块 K i K_i Ki 之间的注意力分数 s i j s_{ij} sij;
-
sij = np.matmul(qi, kj.transpose(0, 2, 1)) / np.sqrt(hidden_size)
-
计算用于更新变量 m 和 l 的相关变量 m i j , P i j , l i j m_{ij},P_{ij},l_{ij} mij,Pij,lij ;
-
mij = np.max(sij, axis=-1) pij = np.exp((sij - mij[..., np.newaxis])) lij = pij.sum(axis=-1)
-
根据上一步计算得到的信息,更新 m 和 l 对应的块;
-
m_new = np.maximum.reduce([mi, mij]) l_new = np.exp(mi - m_new) * li + np.exp(mij - m_new) * lij
-
根据更新之后的 m 和 l,更新输出矩阵 O 对应的块 O i O_i Oi。 O i O_i Oi 的更新方式和 l i l_i li 类似;
-
temp = li[..., np.newaxis] * np.exp(mi - m_new)[..., np.newaxis] * oi + np.exp(mij - m_new)[..., np.newaxis] * np.matmul(pij, vj) temp /= l_new[..., np.newaxis] self.write(out, temp, i, i + 1, self.row_block_size)
-
将更新之后的 m i m_i mi 和 l i l_i li 从 SRAM 写回到 HBM;
-
self.write(m, m_new, i, i + 1, self.row_block_size) self.write(l, l_new, i, i + 1, self.row_block_size)
-
-
Flash Attention 的特点在于尽量减少 GPU 的 HBM 和片上 SRAM 之间的数据交换,从而达到加速运算以及节省显存的目的。Flash Attention 的核心方法是 tiling 和 recomputation。其中 tiling 递推式地计算 softmax,避免了计算整个注意力权重矩阵,而 recomputation 则基于前向运算中的 tiling 保存的某些中间变量,在反向传播时重新计算注意力权重矩阵。
-
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V Attention(Q,K,V)=softmax(dQKT)V
-
存在的问题: Q ∗ K T Q*K^T Q∗KT的结果非常巨大,当sequence length增加时,显存占用以 O ( N 2 ) O(N^2) O(N2) 比例增加。可以想到通过 Tiling 切块的方式减少中间显存占用,但是Softmax这个算子给直接Tiling带来了困难:Softmax的计算逻辑限制了需要等待整行的结果计算完毕后才能继续。FlashAttention在附录里提到他的独创性主要在以下三点:
- Google Research的工作重点在减少整个过程的memory footprint;FlashAttention重点在减少memory reads/writes次数。可以说FlashAttention主要是从GPU block/thread并行度的视角对访存进行了优化。
- Google Research的工作每个block会产出一份中间结果,所有block执行完毕之后,再将他们的中间结果计算获得一个最终结果;FlashAttention则采用类似滑动窗口的方式,第 i 个block会将累积的中间结果传递给第 i + 1 i+1 i+1 个block,也就是说最后一个block计算完毕后,可以保证整行的Softmax逻辑计算正确性。
- Google Research的工作在后向backward的时候做了一些冗余计算,FlashAttention把后向的计算简化了,减少了backward阶段的memory traffic。
-
理解From online softmax to FlashAttention[1805.02867] Online normalizer calculation for softmax (arxiv.org)需要四个步骤
- softmax
- safe softmax
- online softmax Tiling
- FlashAttention Tiling
-
FlashAttention之所以可以省显存(显存开销随Seq length线性增加),是因为解开了softmax以及后面GEMM的行方向依赖,并且通过辅助数组保存的辅助信息re-scale到正确的数值。在Flash Attention之前,也出现过一些加速Transformer计算的方法,这些方法的着眼点是“减少计算量FLOPs”,例如用一个稀疏 attention 做近似计算。但是Flash attention就不一样了,它并没有减少总的计算量,因为它发现:计算慢的卡点不在运算能力,而是在读写速度上。所以它通过降低对显存(HBM)的访问次数来加快整体运算速度,这种方法又被称为O-Awareness。Flash Attention是通过分块计算(tiling)和核函数融合(kernel fusion)来降低对显存的访问。
-
Memory Efficicent,节省显存。在标准attention场景中,forward时我们会计算并保存N*N大小的注意力矩阵;在backward时我们又会读取它做梯度计算,这就给硬件造成了 O ( N 2 ) O(N^2) O(N2) 的存储压力。在Flash Attention中,则巧妙避开了这点,使得存储压力降至O(N)。
-
Exact Attention,精准注意力。之前的办法会采用类似于“稀疏attention”的方法做近似。这样虽然能减少计算量,但算出来的结果并不完全等同于标准 attention下的结果。但是Flash Attention却做到了完全等同于标准attention的实现方式。
-
π \pi π∶硬件算力上限。指的是一个计算平台倾尽全力每秒钟所能完成的浮点运算数。单位是FLOPS or FLOP/s。
-
β∶硬件带宽上限。指的是一个计算平台倾尽全力每秒所能完成的内存交换量。单位是Byte/s。
-
π t \pi_t πt︰某个算法所需的总运算量,单位是FLOPs。下标 t t t 表示total。
-
β t \beta_t βt∶某个算法所需的总数据读取存储量,单位是Byte。下标 t t t 表示total。
-
FLOPS:等同于FLOP/s,表示Floating Point Operations Per Second,即每秒执行的浮点数操作次数,用于衡量硬件计算性能。
-
FLOPs:表示Floating Point Operations,表示某个算法的总计算量(即总浮点运算次数),用于衡量一个算法的复杂度。
-
T c a l T_{cal} Tcal: 对某个算法而言,计算所耗费的时间,单位为 s。其满足 T c a l = π t π T_{cal}=\frac{\pi_t}{\pi} Tcal=ππt。
-
T l o a d T_{load} Tload∶对某个算法而言,读取存储数据所耗费的时间,单位为s。具满足 T l o a d = β t β T_{load}=\frac{\beta_t}{\beta} Tload=ββt。
-
计算限制:当 T c a l > T l o a d T_{cal}>T_{load} Tcal>Tload 时,算法运行的瓶颈在计算上,我们称这种情况为计算限制(math-bound)。此时我们有: π t π > β t β \frac{\pi_t}{\pi}>\frac{\beta_t}{\beta} ππt>ββt,即 π t β t > π β \frac{\pi_t}{\beta_t}>\frac\pi\beta βtπt>βπ。
-
内存限制:当 T c a l < T l o a d T_{cal} <T_{load} Tcal<Tload 时,算法运行的瓶颈在数据读取上,我们称这种情况为内存限制(memory-bound)。此时 π t π < β t β \frac{\pi_t}{\pi}<\frac{\beta_t}{\beta} ππt<ββt,即 π t β t < π β \frac{\pi_t}{\beta_t}<\frac{\pi}{\beta} βtπt<βπ。