Gemma模型报告中提到的几个点进行代码细节解读一下:
(1)Embedding层共享参数
(2)输入输出层均进行RMSNorm
Embedding层共享参数
共享embedding的权重给最后的llm_head层。是词嵌入层的共享,与旋转位置编码无任何关系。早期BERT、GPT等都使用了此操作。代码实现也非常简单,即nn.parameters定义weight,使用F.embedding,最后llm_head层传入weight参数。
Gemma源码中主要是下面这一段:
如下图所示,上面是不共享参数的流程,下面是共享参数的流程:
输入输出层均进行RMSNorm
也很简单,流程图如下,其实这种做法是现在公认的了,LLAMA中也是两次RMSNorm。
源码中主要是下面部分: