【llm对话系统】大模型 Llama 源码分析之并行训练方案

1. 引言

训练大型语言模型 (LLM) 需要巨大的计算资源和内存。为了高效地训练这些模型,我们需要采用各种并行策略,将计算和数据分布到多个 GPU 或设备上。Llama 作为当前最流行的开源大模型之一,其训练代码中采用了多种并行技术。本文将深入 Llama 的训练代码,分析其并行训练方案,主要关注参数并行部分结构参数共享

2. 并行训练策略概述

常见的并行训练策略包括:

  • 数据并行 (Data Parallelism, DP):将数据分成多个 batch,每个 GPU 处理一个 batch,所有 GPU 使用相同的模型副本。
  • 模型并行 (Model Parallelism, MP):将模型分成多个部分,每个 GPU 负责模型的一部分。
  • 流水线并行 (Pipeline Parallelism, PP):将模型的不同层分配到不同的 GPU 上,形成一个流水线。
  • 张量并行 (Tensor Parallelism, TP):将模型的张量 (例如,权重矩阵) 分片到多个 GPU 上。
  • 序列并行 (Sequence Parallelism, SP): 将序列长度分片到多个 GPU 上。

Llama 主要采用了数据并行张量并行,以及一些结构参数共享的优化。

3. Llama 中的参数并行方案

Llama 使用了 ZeRO (Zero Redundancy Optimizer) 技术,这是一种强大的内存优化方法,它结合了数据并行和模型并行。ZeRO 的核心思想是将模型状态 (权重、梯度和优化器状态) 分片到多个 GPU 上,从而减少每个 GPU 的内存占用。

ZeRO 有三个阶段:

  • ZeRO-1 (Optimizer State Partitioning):将优化器状态 (例如,Adam 的动量和方差) 分片。
  • ZeRO-2 (Gradient Partitioning):在 ZeRO-1 的基础上,将梯度也分片。
  • ZeRO-3 (Parameter Partitioning):在 ZeRO-2 的基础上,将模型参数也分片。

Llama 主要使用了 ZeRO-3,将模型参数、梯度和优化器状态都分片到多个 GPU 上。

3.1 参数分片计算

在 Llama 的训练代码中, 以 torch.distributed.fsdp 库为例 (Fully Sharded Data Parallel, FSDP),它实现了 ZeRO-3 的功能。

以下是一个简化的 FSDP 参数分片示例:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import functools

# 假设我们有一个简单的 Transformer 模型
class TransformerLayer(torch.nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(hidden_dim, 4 * hidden_dim)
        self.linear2 = torch.nn.Linear(4 * hidden_dim, hidden_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.nn.functional.relu(x)
        x = self.linear2(x)
        return x

# 初始化分布式环境
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device(f"cuda:{rank}")

# 模型和优化器
hidden_dim = 768
model = TransformerLayer(hidden_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 使用 FSDP 包装模型
# 使用自动包装策略
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerLayer,},
)
model = FSDP(model, fsdp_auto_wrap_policy=auto_wrap_policy, device_id=device)

# 模拟训练数据
x = torch.randn(1, 10, hidden_dim, device=device)

# 前向传播
y = model(x)

# 反向传播
loss = y.sum()
loss.backward()

# 优化器更新
optimizer.step()

# 清空梯度
optimizer.zero_grad()

print(f"Rank {rank}: 训练完成")

代码解释:

  1. 初始化分布式环境dist.init_process_group("nccl") 初始化分布式训练环境,使用 NCCL 后端。
  2. 模型和优化器:定义一个简单的 Transformer 层和 Adam 优化器。
  3. FSDP 包装:使用 FullyShardedDataParallel 包装模型。这会将模型参数分片到多个 GPU 上。
    • transformer_auto_wrap_policy 会自动将模型的每一层都用 FSDP 包装起来。
  4. 前向和反向传播:执行模型的前向和反向传播。
  5. 优化器更新:执行优化器的 step 方法。
  6. 清空梯度:清空梯度以进行下一次迭代。

运行方式:

你需要使用 torchrun(或 torch.distributed.launch)来启动这个脚本,例如:

torchrun --nproc_per_node=4 fsdp_example.py

这将使用 4 个 GPU 来训练模型。

原理说明:

  • 参数分片:在 model = FSDP(model, ...) 这一步,模型参数被分片到 4 个 GPU 上。每个 GPU 只存储一部分参数。
  • All-gather:在前向传播过程中,当需要完整的参数进行计算时(例如,矩阵乘法),FSDP 会自动执行 all-gather 操作,将所有 GPU 上的参数片段收集起来,组成完整的参数。
  • Reduce-scatter:在反向传播过程中,梯度也是分片的。计算完梯度后,FSDP 会执行 reduce-scatter 操作,将每个参数的梯度片段 reduce 到对应的 GPU 上。
  • 优化器更新:每个 GPU 使用自己分片到的参数和梯度来更新优化器状态。

通过这种方式,FSDP 显著减少了每个 GPU 的内存占用,使得训练更大的模型成为可能。

4. Llama 中的部分结构参数共享

除了参数分片,Llama 还采用了一些结构参数共享的优化,以进一步减少内存占用和提高训练效率。

例如在 Transformer 的多头注意力 (Multi-Head Attention) 机制中,不同 head 的 query, key, value 矩阵的计算通常是独立的。Llama 通过共享 key 和 value 矩阵,减少了参数量和计算量。更具体地说,llama使用了分组注意力机制(Grouped-Query Attention)。

4.1 分组注意力 (Grouped-Query Attention, GQA)

GQA 介于标准的多头注意力 (MHA) 和 Multi-Query Attention (MQA) 之间。

  • MHA: 每个 head 都有独立的 Q, K, V 矩阵。
  • MQA: 所有 head 共享 K, V 矩阵,只有 Q 矩阵是独立的。
  • GQA: 将 head 分成多个组,每个组内的 head 共享 K, V 矩阵。

例如:

假设我们有 8 个 head,可以将它们分成 4 个组,每个组 2 个 head。这样,我们就只有 4 个 K 矩阵和 4 个 V 矩阵,而不是 8 个。

代码示例 (简化版)

import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, num_groups):
        super().__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = embed_dim // num_heads
        self.group_size = num_heads // num_groups

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # 共享 K, V 矩阵
        self.k_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
        self.v_proj = nn.Linear(embed_dim, num_groups * self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # 计算 Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_groups, self.head_dim)

        # 将 head 分组
        q = q.view(batch_size, seq_len, self.num_groups, self.group_size, self.head_dim)

        # 计算注意力
        attn_scores = torch.einsum("bqlgd,bklhd->bqgkh", q, k) / (self.head_dim ** 0.5)
        attn_probs = attn_scores.softmax(dim=-1)
        attn_output = torch.einsum("bqgkh,bklhd->bqlgd", attn_probs, v)

        # 拼接 head
        attn_output = attn_output.reshape(batch_size, seq_len, embed_dim)

        # 输出投影
        output = self.out_proj(attn_output)

        return output

# 示例
embed_dim = 768
num_heads = 8
num_groups = 4
model = GroupedQueryAttention(embed_dim, num_heads, num_groups)
x = torch.randn(1, 10, embed_dim)
y = model(x)
print(y.shape)

代码解释:

  1. num_groups:将 head 分成多少个组。
  2. k_projv_proj:只输出 num_groups 个 head 的 K 和 V 矩阵。
  3. q 分成 num_groups 个组,每个组 group_size 个 head。
  4. 计算注意力时,每个组内的 head 共享 K 和 V 矩阵。

GQA 的优势:

  • 减少参数量:K 和 V 矩阵的数量减少了。
  • 减少计算量:计算注意力时,每个 head 需要处理的 K 和 V 数量减少了。
  • 性能接近 MHA:实验表明,GQA 的性能接近 MHA,明显优于 MQA。

llama的GQA实现在llama/model.py文件中,class Attention(nn.Module) 类下的forward函数中,更具体地,体现在self.num_headsself.num_kv_heads的参数上, 分别控制querykvhead数量,num_kv_heads小于num_heads

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962322.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)

Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

【基于SprintBoot+Mybatis+Mysql】电脑商城项目之用户注册

🧸安清h:个人主页 🎥个人专栏:【计算机网络】【Mybatis篇】 🚦作者简介:一个有趣爱睡觉的intp,期待和更多人分享自己所学知识的真诚大学生。 目录 🎯项目基本介绍 🚦项…

视频多模态模型——视频版ViT

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细解读多模态论文《ViViT: A Video Vision Transformer》,2021由google 提出用于视频处理的视觉 Transformer 模型,在视频多模态领域有…

DeepSeek本地部署(windows)

一、下载并安装Ollama 1.下载Ollama Ollama官网:Ollama 点击"Download",会跳转至下载页面。 点击"Download for Windows"。会跳转Github进行下载,如下载速度过慢,可在浏览器安装GitHub加速插件。 2.安装Ollama 双击下载的安装文件,点击"Inst…

1 HDFS

1 HDFS 1. HDFS概述2. HDFS架构3. HDFS的特性4. HDFS 的命令行使用5. hdfs的高级使用命令6. HDFS 的 block 块和副本机制6.1 抽象为block块的好处6.2 块缓存6.3 hdfs的文件权限验证6.4 hdfs的副本因子 7. HDFS 文件写入过程(非常重要)7.1 网络拓扑概念7.…

全国31省空间权重矩阵(地理相邻空间、公路铁路地理距离空间、经济空间)权重矩阵数据-社科数据

中国31个省份空间权重矩阵-社科数据https://download.csdn.net/download/paofuluolijiang/90028597 https://download.csdn.net/download/paofuluolijiang/90028597 空间权重矩阵是反映个体在空间中依赖关系的矩阵,本数据计算全国31个省三种标准化处理的空间权重矩…

Flask框架基础入门教程_ezflaskapp

pip install flaskFlask 快速入门小应用 学东西,得先知道我们用这个东西,能做出来一个什么东西。 一个最小的基于flask 的应用可能看上去像下面这个样子: from flask import Flask app Flask(__name__)app.route(/) def hello_world():ret…

机器学习笔记——特征工程

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本笔记介绍机器学习中常见的特征工程方法、正则化方法和简要介绍强化学习。 文章目录 特征工程(Fzeature Engineering)1. 特征提取&#xff…

Cursor火出圈,未来程序员还有出路吗?

大家好,我是凡人。 今天我表弟家邻居的阿姨,托他问问我目前程序员还有前景吗,希望我根据十几年的经验给出点建议,看看程序员这条路未来能不能走。 一下子不知道该怎么回复他了,如果是三年前问我,肯定毫不…

如何移植ftp服务器到arm板子?

很多厂家提供的sdk,一般都不自带ftp服务器功能, 需要要发人员自己移植ftp服务器程序。 本文手把手教大家如何移植ftp server到arm板子。 环境 sdk:复旦微 Buildroot 2018.02.31. 解压 $ mkdir ~/vsftpd $ cp vsftpd-3.0.2.tar.gz ~/vs…

第5章 公共事件

HarmonyOS通过公共事件服务为应用程序提供订阅、发布、退订公共事件的能力。 5.1 公共事件概述 在应用里面,往往会有事件。比如,朋友给我手机发了一条信息,未读信息会在手机的通知栏给出提示。 5.1.1 公共事件的分类 公共事件&#xff08…

STM32 对射式红外传感器配置

这次用的是STM32F103的开发板(这里面的exti.c文件没有how to use this driver 配置说明) 对射式红外传感器 由一个红外发光二极管和NPN光电三极管组成,M3固定安装孔,有输出状态指示灯,输出高电平灯灭,输出…

【数据结构】(2)时间、空间复杂度

一、衡量算法好坏的指标 时间复杂度衡量算法的运行速度,空间复杂度衡量算法所需的额外空间。这些指标,是某场景中选择使用哪种数据结构和算法的依据。如今,计算机的存储器已经变得容易获得,所以不再太关注空间复杂度。 二、渐进表…

FBX SDK的使用:基础知识

Windows环境配置 FBX SDK安装后,目录下有三个文件夹: include 头文件lib 编译的二进制库,根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库,需要在配置属性->C/C->预…

Ansible自动化运维实战--yaml的使用和配置(7/8)

文章目录 一、YAML 基本语法1.1. 缩进1.2. 注释1.3. 列表1.4. 字典 二、Ansible 中 YAML 的应用2.1. Ansible 剧本(Playbooks)2.2. 变量定义2.3. 角色(Roles)2.4. Inventory 文件2.5. 数据类型2.6. 引用变量 在 Ansible 里&#x…

springboot集成钉钉,发送钉钉日报

目录 1.说明 2.示例 3.总结 1.说明 学习地图 - 钉钉开放平台 在钉钉开放文档中可以查看有关日志相关的api,主要用到以下几个api: ①获取模板详情 ②获取用户发送日志的概要信息 ③获取日志接收人员列表 ④创建日志 发送日志时需要根据模板规定日志…

Node.js下载安装及环境配置教程 (详细版)

Node.js:是一个基于 Chrome V8 引擎的 JavaScript 运行时,用于构建可扩展的网络应用程序。Node.js 使用事件驱动、非阻塞 I/O 模型,使其非常适合构建实时应用程序。 Node.js 提供了一种轻量、高效、可扩展的方式来构建网络应用程序&#xff0…

ProfiNet转CANopen应用于汽车总装生产线输送设备ProfiNet与草棚CANopen质量检测系统

ProfiNet转CANopen协议转换网关模块,广泛应用于汽车行业。可替代NT 100-RE-CO和AB7658/7307产品功能 项目概述 在汽车总装生产线的末尾环节,汽车总装生产线输送设备起着关键的搬运作用,其基于 ProfiNet 协议运行,精准控制车辆在各…

「全网最细 + 实战源码案例」设计模式——桥接模式

核心思想 桥接模式(Bridge Pattern)是一种结构型设计模式,将抽象部分与其实现部分分离,使它们可以独立变化。降低代码耦合度,避免类爆炸,提高代码的可扩展性。 结构 1. Implementation(实现类…

Attention--人工智能领域的核心技术

1. Attention 的全称与基本概念 在人工智能(Artificial Intelligence,AI)领域,Attention 机制的全称是 Attention Mechanism(注意力机制)。它是一种能够动态分配计算资源,使模型在处理输入数据…