Gemma模型论文详解(附源码)

原文链接:Gemma模型论文详解(附源码)

1. 背景介绍

Gemma模型是在2023.2.21号Google新发布的大语言模型, Gemma复用了Gemini相同的技术(Gemini也是Google发布的多模态模型),Gemma这次发布了了2B和7B两个版本的参数,不仅提供了预训练的checkpoints,还提供了用于对话、指令跟随等fine-tune的checkpoints。在QA问答、常识。在11

在这里插入图片描述

2. 模型介绍

2.1 模型结构

Gemma模型使用了transformer decoder结构进行训练,训练的上下文大小为8192个token,模型参数如下:
在这里插入图片描述

相比原始transformer结构的区别:

  • Multi-Query Attention:7B模型使用了multi-head attention,2B模型使用了multi-query attention (with 𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1)。对比llama2中用了group-query attention
    在这里插入图片描述

  • RoPE Embeddings: 不使用绝对位置编码,在每一层前加下RoPE Embedding,同时共享输入与输出层的embedding权重。

  • GeGLU Activations: ReLU的激活替换为GeGLU的激活。对比llama中用了swiglu。

  • Normalizer Location: 在transformer的每一层layer的前后都进行规一化,这里使用RMSNorm做为规一化层。

2.2 训练搭建

Gemma使用TPUv5e进行训练;一个pod中有256块TPUv5e芯片,256块芯片被设计为16X16的2D拓扑;Gemma-7B使用16个pods(4096块卡)进行训练,Gemma-2B使用2个pods(512块卡)。7B模型在一个pod内使用16路模型并行和16路数据并行,2B模型在一个pod内使用256路数据并行。优化器状态使用ZeRO-3进行切分,减少显存占用。在pod外使用类似Pathways的方式减少数据复制的成本。

和Gemini模型训练一样,综合了Jax和Pathways的单控制器single controller编程范式,使用单个python进程编排整个训练; 使用GSPMD partitioner用于训练step的计算,使用XLA compiler减少中间结果的大小。

2.3 训练数据

Gemma 2B和7B分别基于2T和6T个token进行训练,token来源于纯英文的文本,内容包括网页、数学、代码等。使用SentencePiece的tokenizer,字典大小有256K个token。数据过滤使用基于模型的分类器去除有害的、低质量的内容。最后采用类似Gemini的方式进行训练数据的混合,提升高质量数据的占比。

2.4 指令微调(Instruction Tuning)

2B和7B进行有监督微调(SFT)训练中使用混合生成数据和人工标注的prompt文本对,同时进行RLHF训练。在SFT阶段,基于给定的一个prompt,通过测试模型生成多个响应的回答结果,通过一个更大更好的模型进行结果的好坏判断。基于不同的侧重方向(指令跟随/事实/创造性/安全等)构建不同的prompt。使用多种基于LM的自动判断方法,比如chain-of-thought prompting

训练和推理过程中使用相同的数据格式,格式的设计重点在于两点,一个是确定多轮对话中的角色,一个是确定一轮对话的开始结束。对应格式标记和示例的训练数据如下:

在这里插入图片描述
在这里插入图片描述

3. 源码

  • Tensorflow实现的源码在github google-deepmind/gemma中,PyTorch实现的源码在github google/gemma_pytorch。

  • 模型的配置在gemma/config.py文件中, 7B与2B区别主要在于num_hidden_layers/num_attention_heads/num_key_value_heads/hidden_size/intermediate_size

@dataclasses.dataclass
class GemmaConfig:
    # The number of tokens in the vocabulary.
    vocab_size: int = 256000
    # The maximum sequence length that this model might ever be used with.
    max_position_embeddings: int = 8192
    # The number of blocks in the model.
    num_hidden_layers: int = 28
    # The number of attention heads used in the attention layers of the model.
    num_attention_heads: int = 16
    # The number of key-value heads for implementing attention.
    num_key_value_heads: int = 16
    # The hidden size of the model.
    hidden_size: int = 3072
    # The dimension of the MLP representations.
    intermediate_size: int = 24576
    # The number of head dimensions.
    head_dim: int = 256
    # The epsilon used by the rms normalization layers.
    rms_norm_eps: float = 1e-6
    # The dtype of the weights.
    dtype: str = 'bfloat16'
    # Whether a quantized version of the model is used.
    quant: bool = False
    # The path to the model tokenizer.
    tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

    def get_dtype(self) -> Optional[torch.dtype]:
        """Gets the torch dtype from the config dtype string."""
        return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)


def get_config_for_7b() -> GemmaConfig:
    return GemmaConfig()


def get_config_for_2b() -> GemmaConfig:
    return GemmaConfig(
        num_hidden_layers=18,
        num_attention_heads=8,
        num_key_value_heads=1,
        hidden_size=2048,
        intermediate_size=16384
    )
  • 模型定义在gemma/model.py文件中,GemmaDecoderLayer的定义如下:
class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
  • GeGLU的实现跟llama的swiglu不同,geglu相比glu区是采用了gelu的激活,以下是glu的计算示例图:
    在这里插入图片描述

代码参考如下,代码中self.gate_proj对应上图中的B矩阵,gate相当于 σ ( B ) \sigma(B) σ(B)self.up_proj对应上图中的A矩阵.

class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate)
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs

4. 参考

  • google-deepmind/gemma
  • Gemma 开放模型
  • Gemma: Open Models Based on Gemini Research and Technology
  • gemma-open-models
  • github google/gemma_pytorch
  • github google-deepmind/gemma
  • Grouped Query Attention论文阅读
  • SwiGLU论文阅读

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/405881.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JAVA--File类与IO流

目录 1. java.io.File类的使用 1.1 概述 1.2 构造器 1.3 常用方法 1、获取文件和目录基本信息 2、列出目录的下一级 3、File类的重命名功能 4、判断功能的方法 5、创建、删除功能 2. IO流原理及流的分类 2.1 Java IO原理 2.2 流的分类 2.3 流的API 3. 节点流之一…

微服务学习

一、服务注册发现 服务注册就是维护一个登记簿,它管理系统内所有的服务地址。当新的服务启动后,它会向登记簿交待自己的地址信息。服务的依赖方直接向登记簿要Service Provider地址就行了。当下用于服务注册的工具非常多ZooKeeper,Consul&am…

Jetson Xavier NX 与笔记本网线连接 ,网络共享,ssh连接到vscode

Jetson Xavier NX 与笔记本网线连接 ,网络共享,ssh连接到vscode Jetson Xavier NX桌面版需要连接显示屏、鼠标和键盘,操作起来并不方便,因此常常需要ssh远程连接到本地笔记本电脑,这里介绍一种连接方式,通过…

Linux实验记录:使用PXE+Kickstart无人值守安装服务

前言: 本文是一篇关于Linux系统初学者的实验记录。 参考书籍:《Linux就该这么学》 实验环境: VmwareWorkStation 17——虚拟机软件 RedHatEnterpriseLinux[RHEL]8——红帽操作系统 备注: 实际生产中安装操作系统的工作&…

论文笔记:利用词对比注意增强预训练汉字表征

整理了 ACL2020短文 Enhancing Pre-trained Chinese Character Representation with Word-aligned Att)论文的阅读笔记 背景模型实验 论文地址:论文 背景 近年来,以 BERT 为代表的预训练模型在 NLP 领域取得取得了非常显著的效果。但是&…

谈谈对BFC的理解

文章目录 一、是什么二、触发条件三、应用场景防止margin重叠(塌陷)清除内部浮动自适应多栏布局小结 参考文献 一、是什么 我们在页面布局的时候,经常出现以下情况: 这个元素高度怎么没了?这两栏布局怎么没法自适应&…

28-k8s集群中-StatefulSets控制器(进阶知识)

一、statefullsets控制器概述 1,举例 假如,我们有一个deployment资源,创建了3个nginx的副本,对于nginx来讲,它是不区分启动或者关闭的先后顺序的,也就是“没有特殊状态”的一个服务,也成“无状…

一次有趣的nginx Tcp4层代理转发的试验

nginx主配置文件添加配置: stream {log_format proxy $remote_addr [$time_local] $protocol status:$status bytes_sent:$bytes_sent bytes_received:$bytes_received $session_time upstream_addr:"$upstream_addr" "$upstream_bytes_sent" …

React18源码: React调度中的3种优先级类型和Lane的位运算

优先级类型 React内部对于优先级的管理,贯穿运作流程的4个阶段(从输入到输出),根据其功能的不同,可以分为3种类型: 1 )fiber优先级(LanePriority) 位于 react-reconciler包,也就是L…

基于Java SSM框架实现网络作业提交与批改系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架实现网络作业提交与批改系统演示 摘要 随着互联网时代的到来,同时计算机网络技术高速发展,网络管理运用也变得越来越广泛。因此,建立一个B/S结构的网络作业提交与批改系统,会使网络作业提交与批改系统工作系统化…

JavaScript字符串的常用方法(非常详细)

文章目录 一、操作方法增concat 删改trim()、trimLeft()、trimRight()repeat()padEnd() toLowerCase()、 toUpperCase()查charAt()indexOf()startWith()、includes() 二、转换方法split 三、模板匹配方法match()search()replace() 一、操作方法 我们也可将字符串常用的操作方法…

c编译器学习07:minilisp编译器改造(debug模式支持调试)

问题 原版的minilisp编译器不支持argv输入测试,不方便单步调试。 代码改造目标是既不改变原有程序的各种功能, 又能支持个人习惯的vs单步debug模式。 CMakeLists.txt变更 定义DEBUG宏 解决单步调试源码定位偏差问题 cmake_minimum_required(VERSION …

Puppeteer 使用实战:如何将自己的 CSDN 专栏文章导出并用于 Hexo 博客(二)

文章目录 上一篇效果演示Puppeteer 修改浏览器的默认下载位置控制并发数错误重试并发控制 错误重试源码 上一篇 Puppeteer 使用实战:如何将自己的 CSDN 专栏文章导出并用于 Hexo 博客(一) 效果演示 上一篇实现了一些基本功能,…

8.qt5使用opencv的库函数打开图片

1.配置opencv动态库的环境变量 2.在创建的qt工程中加入如下opencv代码,具体代码如下: 使用opencv库函数显示图片

Mamba详细介绍和RNN、Transformer的架构可视化对比

Transformer体系结构已经成为大型语言模型(llm)成功的主要组成部分。为了进一步改进llm,人们正在研发可能优于Transformer体系结构的新体系结构。其中一种方法是Mamba(一种状态空间模型)。 Mamba: Linear-Time Sequence Modeling with Select…

JVM面试题(1)

1.说一下jvm的主要组成部分,以及作用 类加载器(ClassLoader):将java代码转换成字节码 运行时数据区(Runtime Data Area):将字节码加载到内存中 执行引擎(Execution Engine&#x…

【Flink数据传输(一)】NetworkStack架构概述:实现tm之间的数据交换

文章目录 1. NetworkStack整体架构2. StreamTask内数据流转过程 NetworkStack提供了高效的网络I/O和反压控制 除了各个组件之间进行RPC通信之外,在Flink集群中TaskManager和TaskManager节点之间也会发生数据交换,尤其当用户提交的作业涉及Task实例运行在…

个人博客系列-环境配置-gitee(2)

注册gitee账户 地址:https://gitee.com/ 此步骤省略 新建仓库 执行以下命令 即可 拉取代码 创建目录 mkdir myCode && cd myCode 登录gitee找到项目,点击克隆,拉取代码 连接远程仓库命令 git remote add origin 仓库地址http…

PCIe 5.0 Layout Guide笔记

一、松耦合和紧耦合 松耦合优点是相同走线宽度下电介质更薄,同时对线间距的变化不敏感,提供了更好的阻抗控制;松耦合缺点是需要更大的区域进行绕线;紧耦合优点是更高的布线密度,相同阻抗下走线可以更细,同时具有更好的共模噪声抑制;紧耦合缺点是阻抗随线间距的变化大;【…

单片机02_寄存器_GPIO设置__点灯

芯片概述 C51:0口、1口、2口、3口,P00~p07、P10~P17、P20~P27、P30~P37 STM32:A口、B口、C口、D口,PA0~PA15/PA5 GPIOA.5 STM32F407ZGT6有7组GPIO端口,分别是:A B C D E F G,每组均有16个GPIO端…