Llama架构及代码详解

Llama的框架图如图:
在这里插入图片描述
源码中含有大量分布式训练相关的代码,读起来比较晦涩难懂,所以我们对llama自顶向下进行了解析及复现,我们对其划分成三层,分别是顶层、中层、和底层,如下:

Llama的整体组成

由上图可知,Llama整体是由1个embedding层,n个transformer层,和1个RMSNorm层组成的,所以顶层代码如下:
顶层

class Llama(torch.nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
       self.config = config
        # embedding层
        self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)
        # RMSNorm
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # n层Transformer
        self.layers = torch.nn.ModuleList()
        for i in range(self.config.n_layers):
            self.layers.append(TransformerBlock(config))


    def forward(self, tokens):
        # 进行token的嵌入编码
        h = self.tok_embeddings(tokens)
        # decoder架构需要生成一个mask
        seqlen = h.shape[1]
        mask = torch.full((seqlen, seqlen), float('-inf'), device=tokens.device)
        mask = torch.triu(mask, diagonal=1)
        # 进行n层Transformer
        for i in range(self.config.n_layers):
            h = self.layers[i](h, mask)
        # 进行RMSNorm
        token_embeddings = self.norm(h)
        return token_embeddings

中层
我们首先进行RMSNorm的复现

class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, tensor):
        return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, tensor):
        output = self._norm(tensor)
        return output * self.weight

然后对Transformer进行复现,在Transformer中,Transformer包括两个RMSNorm层,一个多头attention层,一个全连接层。

class TransformerBlock(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # 多头注意力层
        self.attention = Attention(config)
        # Norm层
        self.attention_normal = RMSNorm(config.dim, config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        # 全连接层
        self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)

    def forward(self, embeddings, mask):
        # norm
        h = self.attention_normal(embeddings)
        # attention
        h = self.attention(h, mask)
        # add & norm
        h = self.ffn_norm(h + embeddings)
        # fnn
        f = self.ffn(h)
        # add
        return f + h

底层
在多头attention中,首先需要对token的嵌入进行空间映射,多头拆分,旋转位置编码,分数计算等操作

class Attention(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_head = config.n_heads
        self.dim = config.dim // self.n_head

        self.k = torch.nn.Linear(config.dim, config.dim)
        self.q = torch.nn.Linear(config.dim, config.dim)
        self.v = torch.nn.Linear(config.dim, config.dim)

    def forward(self, embeddings, mask):
        bsz, seq_len, dim = embeddings.shape

        k_embeddings = self.k(embeddings)
        q_embeddings = self.q(embeddings)
        v_embeddings = self.v(embeddings)
        n_q_embeddings = q_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_k_embeddings = k_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)
        n_v_embeddings = v_embeddings.reshape(bsz, -1, self.n_head, self.dim).permute(0, 2, 1, 3)

        rotated_n_q_embeddings = compute_rotated_embedding(n_q_embeddings, self.dim, seq_len, self.config.rope_theta)
        rotated_n_k_embeddings = compute_rotated_embedding(n_k_embeddings, self.dim, seq_len, self.config.rope_theta)

        scores = torch.nn.functional.softmax(mask + rotated_n_q_embeddings @ rotated_n_k_embeddings.transpose(-1, -2)
                               / math.sqrt(self.dim), dim=-1)

        n_embeddings = scores @ n_v_embeddings
        embeddings = n_embeddings.permute(0, 2, 1, 3).reshape(bsz, -1, self.config.dim)

        return embeddings
class FeedForwad(torch.nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, hidden_dim)
        self.linear2 = torch.nn.Linear(dim, hidden_dim)
        self.linear3 = torch.nn.Linear(hidden_dim, dim)

    def forward(self, embeddings):
        gate = torch.nn.functional.silu(self.linear1(embeddings))
        up_proj = self.linear2(embeddings) * gate
        return self.linear3(up_proj)

最后,我们复现旋转位置编码,至此我们捋清了llama的所有结构!

def compute_rotated_embedding(embedding, dim, m, base):
    # 计算所有嵌入位置的旋转角度
    all_theta = compute_all_theta(dim, m, base)
    # 旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    # 1、将嵌入投影到复数平面
    embedding_real_pair = embedding.reshape(*embedding.shape[:-1], -1, 2)
    embedding_complex_pair = torch.view_as_complex(embedding_real_pair)
    # 2、将旋转角度投影到复数平面
    all_theta = all_theta[: embedding.shape[-2]]
    theta_complex_pair = torch.polar(torch.ones_like(all_theta), all_theta)
    # 3、旋转后嵌入位置 = 复数平面上初始位置 * 复数平面上角度坐标
    rotated_complex_embedding = embedding_complex_pair * theta_complex_pair
    # 4、将复数平面的嵌入投影到实数平面
    rotated_real_embedding = torch.view_as_real(rotated_complex_embedding)
    rotated_real_embedding = rotated_real_embedding.reshape(*embedding.shape[:-1], -1)
    return rotated_real_embedding

def compute_all_theta(dim, m, base):
    theta = 1 / (base ** (torch.arange(0, dim / 2).float() / (dim / 2)))
    m = torch.arange(0, m)
    all_theta = torch.outer(m, theta)
    return all_theta

附录:llama的config参数

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048
    use_scaled_rope: bool = True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/916109.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

冒泡选择法(c基础)

适合对象c语言初学者。 冒泡选择法 作用对一个数组进行排序。(介绍一下数组(c基础)(详细版)-CSDN博客) 核心要点 1: 数组元素个数 sz 2: 比较后的交换。 核心思路 进行(sz - 1)趟,每一趟把最大数的放到末尾。其…

深入浅出《钉钉AI》产品体验报告

1. 引言 随着人工智能技术的迅猛发展,企业协同办公领域迎来了新的变革。钉钉作为阿里巴巴集团旗下的企业级通讯与协同办公平台,推出了钉钉AI助理,旨在提高工作效率,优化用户体验。本报告将对钉钉AI助理进行全面的产品体验分析&am…

【GPTs】Gif-PT:DALL·E制作创意动图与精灵动画

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Gif-PT主要功能适用场景优点缺点 💯小结 💯GPTs指令 中文翻译: 使用Dalle生成用户请求的精灵图动画&#…

一文看懂什么1688跨境(专业解析)

1688跨境是什么? 最近火出圈的一个新词 今天老余带大家走近1688跨境 首先为什么会出现1688跨境? 因为: 新型服务商崛起,反向海淘成趋势 在过去三年,1688涌现了一批新型的服务商: 他们帮助海外B类买家到1688采购&#xff…

SpringBoot+Vue3实现数据可视化大屏

前端工程的地址:UserManagerFront: 数据可视化前端 (gitee.com) 效果展示,可以展现出来了,样式可能还有一些丑。 后端代码 后端主要是拿到数据并对数据进行处理,按照前端需要的格式进行返回即可。 import com.njitzx.entity.Student; impor…

<项目代码>YOLOv8 手机识别<目标检测>

YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…

算法定制LiteAIServer摄像机实时接入分析平台烟火识别检测算法

在公共安全领域,火灾的及时发现与处理是保障人民群众生命财产安全的重要手段。传统的烟火检测手段受限于人工巡查的局限,难以做到全天候、无遗漏的监控。然而,随着人工智能技术的飞速发展,LiteAIServer摄像机实时接入分析平台烟火…

JMeter与大模型融合应用之JMeter日志分析服务化实战应用

JMeter与大模型融合应用之JMeter日志分析服务化 引言 在当今的互联网时代,网站和应用程序的性能直接影响到用户的体验和业务的成功。为了保证系统的稳定性和高效性,性能测试成为了软件开发过程中的一个重要环节。在这其中,Apache JMeter作为一款开源的性能测试工具,凭借其…

【Pikachu】任意文件上传实战

将过去和羁绊全部丢弃,不要吝惜那为了梦想流下的泪水。 1.不安全的文件上传漏洞概述 不安全的文件上传漏洞概述 文件上传功能在web应用系统很常见,比如很多网站注册的时候需要上传头像、上传附件等等。当用户点击上传按钮后,后台会对上传的…

STM32 ADC --- 单通道采样

STM32 ADC — 单通道采样 文章目录 STM32 ADC --- 单通道采样cubeMX配置代码修改:应用 使用cubeMX生成HAL工程 需求:有多个通道需要进行ADC采样,实现每次采样只采样一个通道,且可以随时采样不同通道的功能。 cubeMX配置 这里我们…

influxDB 时序数据库安装 flux语法 restful接口 nodjsAPI

安装 Install InfluxDB | InfluxDB OSS v2 Documentation Debian和Ubuntu用户可以用apt-get包管理来安装最新版本的InfluxDB。 对于Ubuntu用户,可以用下面的命令添加InfluxDB的仓库,添加之后即可apt-get 安装influxdb2 wget -q https://repos.influx…

【轻量化】YOLOv10 更换骨干网络之 MobileNetv4 | 模块化加法!非 timm 包!

之前咱们在这个文章中讲了timm包的加法,不少同学反馈要模块化的加法,那么这篇就讲解下模块化的加法,值得注意的是,这样改加载不了mobilebnetv4官方开源的权重了~ 论文地址:https://arxiv.org/pdf/2404.10518 代码地址:https://github.com/tensorflow/models/blob/master…

电气自动控制电路图图例

单相照明双路互备自投供电电路 双路三相电源自投电路 茶炉水加热自动控制电路 简单的温度控制器电路 简易晶闸管温度自动控制电路 用双向晶闸管控制温度电路 XCT-101动圈式温度调节仪控温电路 电接点压力式温度表控温电路 TDA-8601型温度指示调节仪控温电路 XMT-DA数字…

D3 可以加载的数据格式有哪些?(12种)

D3.js 支持多种数据格式,这些格式涵盖了从简单的表格数据到复杂的地理数据。以下是一些常见的数据格式及其加载方法: D3.js 数据加载方法 d3.blob(input, init) 用途: 加载二进制数据,返回一个 Blob 对象。参数: input: 数据源 URL。init: …

Pinpoint(APM)进阶--Pinot指标采集(System Metric/Inspector)

接上文 Pinpoint使用Pinot进行指标数据存储,Pinot流摄入需要Kafka 本文详解Kafka和Pinot的安装部署,以及Pinpoint的指标采集 Pinot 简介 Apache Pinot是一个实时分布式OLAP数据存储,专为低延迟、高吞吐量分析而构建,非常适合面…

mmsegmentation: 安装 并使用自定义数据集进行训练 ·2

我用的是CHASE_DB1.py 数据集下载链接 https://staffnet.kingston.ac.uk/~ku15565/CHASE_DB1/assets/CHASEDB1.zip 这个用来转换mmseg所需要的格式 python tools/dataset_converters/chase_db1.py /path/to/CHASEDB1.zip其他数据集请看链接:https://mmsegmenta…

通义千问API调用测试 (colab-python,vue)

文章目录 代码(来自官网)colab中用python测试Qwen2.5在官网上查看并确定过期时间这里看到我的免费额度到25年5月在同一个页面,点击API示例 前端调用直接在前端调用的优缺点以vue为例(代码是基于官网node.js的代码转换而来&#xf…

【c++笔试强训】(第九篇)

目录 链表相加(⼆)(链表⾼精度加法) 题目解析 讲解算法原理 编写代码 ⼤数乘法(⾼精度乘法) 题目解析 讲解算法原理 辨析代码 链表相加(⼆)(链表⾼精度加法&#…

C#高级:使用Invoke关键字通过 Type 类型调用指定的方法

demo如下: using System.Reflection; using System;public class Program {public class Calculator{public int Add(int a, int b){return a b;}}public class Student{public string Name { get; set; }}public class Example{// 泛型方法public string Generi…

B-树特点以及插入、删除数据过程

B树(B-Tree)是一种自平衡的多路查找树,它广泛应用于数据库索引和文件系统中,尤其适用于外部存储设备(如磁盘)。B树的设计使得它能够高效地存储大量数据并支持高效的插入、删除和查询操作。以下是B树的主要特…