模型图
代码及分析
不需要考虑任何mask问题,直接通过矩阵计算求出下三角矩阵每个元素的值即可,不需要额外添加mask之类的。
temperature=0(确定性)的时候,模型推理每次都取概率最大的(从而导致同样的输入prompt会有完全相同的输出);否则根据概率分布来挑选,即有一定概率输出和前一个字不搭配的字
多头注意力机制有两种理解,实现和效果也不同,一种是将embedding维切分成head_num个m=embedding/head_num维,产生m组不同的qkv(维度也不再是embedding)分别对切分后的m组向量做注意力(一一对应)最后拼起来还原为embedding维,另一种理解是,不需要对embedding切分,而是用正常embedding维的大小的qkv分别做注意力,最后也是拼接起来,这时候embedding维拓展成head_num*embedding维,只需要再用一个矩阵线性变换为embedding维即可。
本文的代码基于第一种理解
kv-cache
主要思想就是通过缓存之前的注意力结果以及只挑当前时间步的Q来计算注意力减少计算量。每个时间步t只需要算当前词Wt对W1~t的注意力,因此只需要用当前词的Qt和K1~t以及V1~t即可求出最终的下一个向量的表示,此时再拼接到之前W1~t-1的向量表示即可
作者也没使用任何矩阵运算库,直接就是根据矩阵的定义行列向量点乘求和求出且只考虑矩阵和向量之间的矩阵乘法