通过 AI Edge Torch 生成式 API 在设备上使用自定义大语言模型

67bb3e711571e6c5f0fb404ddf23cff9.png

作者 / 首席工程师 Cormac Brick,软件工程师 Haoliang Zhang

我们很高兴地发布 AI Edge Torch 生成式 API,它能将开发者用 PyTorch 编写的高性能大语言模型 (LLM) 部署至 TensorFlow Lite (TFLite) 运行时,从而无缝地将新的设备端生成式 AI 模型部署到边缘设备上。本文是 Google AI Edge 博客连载的第二篇。上一篇文章为大家介绍了 Google AI Edge Torch,该产品可以在使用 TFLite 运行时的设备上实现高性能的 PyTorch 模型推理。

AI Edge Torch 生成式 API 使开发者能够在设备上引入强大的新功能,例如摘要生成、内容生成等。我们之前已经通过 MediaPipe LLM Inference API 让开发者们能够将一些最受欢迎的 LLM 部署到设备上。现在,我们很高兴能进一步拓展对模型的支持范围,并让大家部署到设备,而且具备优秀的性能表现。今天发布的 AI Edge Torch 生成式 API 是初始版本,提供以下功能:

  • 简单易用的模型创作 API,支持自定义 Transformer。

  • 在 CPU 上性能表现出色,并即将支持 GPU 和 NPU。

  • 作为 AI Edge Torch 的扩展,支持 PyTorch。

  • 完全兼容现有的 TFLite 部署流程,包括量化和运行时。

  • 支持 TinyLlama、Phi-2 和 Gemma 2B 等模型。

  • 兼容 TFLite 运行时和 Mediapipe LLM 运行时接口,支持 Android、iOS 和 Web。

  • MediaPipe LLM Inference API

    https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference

  • AI Edge Torch

    https://ai.google.dev/edge/lite/models/convert_pytorch

我们将在本文中为大家深入介绍该 API 的性能、可移植性、创作开发体验、端到端推理流水线和调试工具链。更具体的文档和示例请查看:

https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/examples

性能表现

为了让 MediaPipe LLM Inference API 顺利支持最受欢迎的一些 LLM,我们的团队手工打造了几款在设备上拥有最佳性能的 Transformer 模型。通过这项工作,我们确定了几个主要课题: 如何有效地表示注意力机制、量化的使用以及良好的 KV 缓存。生成式 API 很好地完成了这些课题 (本文后面会具体提到),而且依然能达到之前手写版本性能的 90% 以上,并大大提高开发速度。

  • 通过 MediaPipe 和 TensorFlow Lite 在设备上运行大语言模型

    https://developers.googleblog.com/en/large-language-models-on-device-with-mediapipe-and-tensorflow-lite/

下表显示了三种模型样本的关键基准测试结果:

d692b53e324cba281ba0e0ca04eec04f.png

  • 三种模型样本

    https://github.com/google-ai-edge/ai-edge-torch/blob/main/ai_edge_torch/generative/examples/README.md

这些基准测试是在大核上运行,使用 4 个 CPU 线程,并且使用了这些模型在所列设备上目前所知最快的 CPU 实现。

创作体验

核心创作库提供了常见 Transformer 模型 (仅编码器、仅解码器或编码-解码器等样式) 的基本构建模块。您可以用它从头开始创作模型,或重新创作现有模型以提高性能。我们建议大多数用户采用重新创作的方式,因为这样就不需要训练 / 微调的步骤了。使用生成式 API 创作的核心优势如下:

  • 一组针对可转换性、性能和平台可移植性进行了优化的核心 Transformer 构建模块,可以轻松与常规 PyTorch 算子进行混合和匹配。

  • 一个简单的权重重映射机制。

  • 直观的量化 API。

  • 支持多签名导出,包括预填充、解码或自定义签名,并能无缝接入现成的 MP 任务 / LLM Inference API。

作为示例,下面展示如何使用新的生成式 API 以约 50 行 Python 代码重新创作 TinyLLama (1.1B) 的核心功能。

  • TinyLLama (1.1B)

    https://github.com/jzhang38/TinyLlama

步骤 1: 定义模型结构

import torch
import torch.nn as nn


from ai_edge_torch.generative.layers.attention import TransformerBlock
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg




class TinyLLamma(nn.Module):


  def __init__(self, config: cfg.ModelConfig):
    super().__init__()


    self.config = config
    # Construct model layers.
    self.lm_head = nn.Linear(
        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
    )
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    self.rope_cache = attn_utils.build_rope_cache(
        size=config.kv_cache_max,
        dim=int(config.attn_config.rotary_percentage * config.head_dim),
        base=10_000,
        condense_ratio=1,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
    )
    self.config = config

步骤 2: 定义模型的前向函数

@torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
    sin = sin.index_select(0, input_pos)
    mask = self.mask_cache.index_select(2, input_pos)
    mask = mask[:, :, :, : self.config.kv_cache_max]


    # forward the model itself
    x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)


    for i, block in enumerate(self.transformer_blocks):
      x = block(x, (cos, sin), mask, input_pos)


    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)
    return res

步骤 3: 映射旧模型权重

您可以使用库中的 ModelLoader API 轻松映射权重,就像这样:

import ai_edge_torch.generative.utilities.loader as loading_utils




# This map will associate old tensor names with the new model.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    pre_attn_norm="model.layers.{}.input_layernorm",
    pre_ff_norm="model.layers.{}.post_attention_layernorm",
    embedding="model.embed_tokens",
    final_norm="model.norm",
    lm_head="lm_head",
)

完成这些步骤后,您可以运行一些示例输入来验证重新创作过的模型的数值正确性。如果数值检查达标,您就可以继续进行后续的转换和量化操作。

  • 验证重新创作的模型

    https://github.com/google-ai-edge/ai-edge-torch/blob/59946008def0ab867c2f4cd8931eaf607ac0d768/ai_edge_torch/generative/test/test_model_conversion.py#L132

转换和量化

通过 ai_edge_torch 提供的转换 API,您可以将 (重新创作的) Transformer 模型转换为高度优化的 TensorFlow Lite 模型。转换过程包含以下关键步骤:

  1. 导出到 StableHLO。通过 torch dynamo 编译器对 PyTorch 模型进行追踪和编译,生成带有 Aten 算子的 FX 计算图,然后由 ai_edge_torch 将其降为 StableHLO 计算图。

  2. ai_edge_torch 在 StableHLO 上执行进一步的编译器操作,包括算子融合 / 折叠等,生成高性能的 TFLite flatbuffer (包含用于 SDPA、KVCache 的融合算子)。

  • StableHLO

    https://github.com/openxla/stablehlo

量化

核心生成式 API 库还提供了一组量化 API,涵盖了常见的 LLM 量化模式。这些模式作为额外参数传递给 ai_edge_torch 转换器 API,由该 API 自动完成量化。我们会在未来的版本中提供更多的量化模式。

多签名导出

我们发现在实际推理场景中,LLM 模型需要有明确分离 (细分) 的推理函数 (预填充、解码),才能实现最佳的服务性能。这部分基于这样的观察: 预填充 / 解码可能需要采用不同的 tensor 形状,预填充受到算力限制,而解码则受到内存限制。对于大型 LLM,避免在预填充 / 解码之间重复模型权重至关重要。我们使用 TFLite 和 ai_edge_torch 中现有的多签名特性来实现这一点,使得开发者能轻松地为模型定义多个入口,如下所示:

def convert_tiny_llama_to_tflite(
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
):
  pytorch_model = tiny_llama.build_model(kv_cache_max_len=kv_cache_max_len)
  
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)


  # Set up Quantization for model.
  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
  
  edge_model = (
      ai_edge_torch.signature(
          'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
      )
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
  edge_model.export(f'/tmp/tiny_llama_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')

针对 LLM 的性能优化

我们在性能调查阶段发现了几个改善 LLM 性能的关键要素:

  1. 高性能的 SDPA 和 KVCache: 我们发现,如果没有足够的编译器优化 / 融合,转换后的 TFLite 模型会因为这些函数中算子的粒度问题,性能不会很好。为了解决这个问题,我们引入了高级函数边界和 StableHLO 复合算子。

  2. 利用 TFLite 的 XNNPack 代理进一步加速 SDPA: 确保大量 MatMul / 矩阵-向量计算得到很好的优化至关重要。XNNPack 库能在广泛的移动 CPU 上以出色的性能完成这些基础计算。

  3. 避免不必要的计算: 静态形状模型如果在预填充阶段有长且固定的输入消息大小,或者在解码阶段有大的固定序列长度,则带来的计算量会大于该模型需要的最小计算量。

  4. 运行时内存消耗: 我们在 TFLite 的 XNNPack 代理中引入了权重缓存 / 预打包机制,显著降低了内存的峰值使用量。

  • SDPA

    https://github.com/google-ai-edge/ai-edge-torch/blob/7f52f70709bc12cf041b3b1fd4a49bc0d52c889a/ai_edge_torch/generative/layers/attention.py#L74

部署

LLM 推理通常涉及许多预处理 / 后处理步骤和复杂的编排,例如令牌化、采样和自回归解码逻辑。为此,我们提供了基于 MediaPipe 的解决方案以及一个纯 C++ 推理示例。

  • 基于 MediaPipe 的解决方案

    https://ai.google.dev/edge/mediapipe/solutions/genai/llm_inference

  • 纯 C++ 推理示例

    https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/examples/c%2B%2B

使用 MediaPipe LLM Inference API

MediaPipe LLM Inference API 是一个高级 API,支持使用 prompt-in / prompt-out 接口进行 LLM 推理。它负责处理底层所有的 LLM 复杂流水线操作,让模型得以更轻松和顺畅地部署。要使用 MediaPipe LLM Inference API 进行部署,您需要使用给定的预填充和解码签名来转换模型,并创建一个任务包,如下方代码所示:

def bundle_tinyllama_q8():
  output_file = "PATH/tinyllama_q8_seq1024_kv1280.task"
  tflite_model = "PATH/tinyllama_prefill_decode_hlfb_quant.tflite"
  tokenizer_model = "PATH/tokenizer.model"
  config = llm_bundler.BundleConfig(
      tflite_model=tflite_model,
      tokenizer_model=tokenizer_model,
      start_token="<s>",
      stop_tokens=["</s>"],
      output_filename=output_file,
      enable_bytes_to_unicode_mapping=False,
  )
  llm_bundler.create_bundle(config)

在 TFLite 运行时使用纯 C++ 推理

我们还提供了一个简单易用的 C++ 示例 (无需 MediaPipe 依赖),来展示如何运行端到端的文本生成。如果您需要将导出的模型与自己独有的生产流水线和需求进行集成,这个示例是一个很好的起点,来帮助您实现更好的定制和灵活性。

跨平台支持

由于核心推理运行时都支持 TFLite,所以整个流水线都可以轻松集成到您的 Android (包括在 Google Play 中) 或 iOS 应用中,无需进行任何修改。这意味着用新的生成式 API 转换的模型只需添加几个自定义算子依赖即可立即部署。在未来的版本中,我们将为 Android 和 iOS 带来 GPU 支持,并支持 ML 加速器 (TPU、NPU)。

工具

最近发布的模型探索器 (Model Explorer) 是一款很好用的工具,可用于可视化诸如 Gemma 2B 之类的大型模型。分层查看和并排比较可以让您轻松查看和比较原始、重新创作和转换后的模型。我们也准备了专门的文章为您进一步介绍该工具,以及如何通过可视化基准信息来优化模型性能。

  • 模型探索器

    https://ai.google.dev/edge/model-explorer

  • 模型探索器: 大模型开发的计算图可视化工具

    https://research.google/blog/model-explorer/

以下是我们在编写 PyTorch TinyLlama 模型时使用该工具的示例。我们并排显示了 PyTorch export() 模型与 TFLite 模型。通过使用模型探索器,我们可以轻松比较每个层级 (如 RMSNorms、SelfAttention) 的表达情况。

d6a3aaba051b7068070714decc35c648.gif

△ 并排比较 TinyLlama PyTorch 和转换后的 TFLite

总结以及下一步

AI Edge Torch 生成式 API 是为 MediaPipe LLM Inference API 预构建优化模型的强大补充,适用于希望在设备上运行自己的生成式 AI 模型的开发者。我们会在接下来的几个月继续带来更新,包括 Web 支持、更好的量化和对 CPU 之外的硬件的支持。我们也会尝试探索更好的框架集成方案。

目前发布的是开发库的早期预览版本,该版本依然处于实验阶段,旨在与开发者社区进行开放互动。API 可能会发生变化,且存在不完善之处,并且对量化和模型的支持有限。但我们已经在 GitHub repo 中为大家提供了很多用于上手的内容,欢迎大家测试和体验,并随时和我们分享 PR、问题和功能需求。

  • GitHub repo

    https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative

在本次连载的第三篇文章中,我们将深入探讨模型探索器可视化工具,了解该工具如何帮助开发者们可视化、调试和探索模型。

  • 模型探索器

    https://ai.google.dev/edge/model-explorer


5e13c25da8e2c39d5e536bc00ad4d718.png

9a02db4d383a5f417307aa954cf086eb.png

a3f99636cb6d86070191d849b1b92855.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/695953.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[大模型]Gemma-2B-Instruct FastApi 部署调用

环境准备 在 平台中租赁一个 3090 等 24G 显存的显卡机器&#xff0c;如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1。 接下来打开刚刚租用服务器的 JupyterLab&#xff0c;并且打开其中的终端开始环境配置、模型下载和运行演示。 pip 换源加速下载…

[qt] qt程序打包以及docker镜像打包

目录 一 环境准备: 1.1 qt环境 1.2 linuxdeplouqt打包工具 二 qt包发布: 2.1 搜索链接库 2.2 应用程序APP打包 2.3 发布 三 docker镜像包发布 3.1 环境准备 3.2 镜像生产脚本 3.3 加载镜像并运行docker容器 一 环境准备: qt环境linuxdeployqt打包工具docker环境 1…

Python学习打卡:day01

day1 笔记来源于&#xff1a;黑马程序员python教程&#xff0c;8天python从入门到精通&#xff0c;学python看这套就够了 1、Python 软件&#xff08;PyCharm&#xff09; 安装&#xff1a;在 Linux 环境下安装 Pycharm 插件&#xff1a;汉化、翻译 设置字体大小 常用快捷…

【MySQL】(基础篇五) —— 排序检索数据

排序检索数据 本章将讲授如何使用SELECT语句的ORDER BY子句&#xff0c;根据需要排序检索出的数据。 排序数据 还是使用上一节中的例子,查询employees表中的last_name字段 SELECT last_name FROM employees;输出结果&#xff1a; 发现其输出并没有特定的顺序。其实&#xf…

【Linux】进程3——PID/PPID,父进程,子进程

在讲父子进程之前&#xff0c;我们接着上面那篇继续讲 1.查看进程 mycode.c makefile 我们在zs_108直接编译mycode.c&#xff0c;直接运行&#xff0c;然后我们转换另一个账号来查看这个进程 我们可以通过ps指令来查看进程 我们就会好奇了&#xff0c;第二行是什么&#xff…

牛客热题:矩阵的最小路径和

&#x1f4df;作者主页&#xff1a;慢热的陕西人 &#x1f334;专栏链接&#xff1a;力扣刷题日记 &#x1f4e3;欢迎各位大佬&#x1f44d;点赞&#x1f525;关注&#x1f693;收藏&#xff0c;&#x1f349;留言 文章目录 牛客热题&#xff1a;矩阵的最小路径和题目链接方法一…

[数据集][目标检测]变电站火灾检测电力场景烟雾明火检测数据集VOC+YOLO格式140张2类别真实场景非PS合成

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;140 标注数量(xml文件个数)&#xff1a;140 标注数量(txt文件个数)&#xff1a;140 标注类别…

模型 SCAMPER创新法则

说明&#xff1a;系列文章 分享 模型&#xff0c;了解更多&#x1f449; 模型_思维模型目录。激发创新的七步思维法。 1 SCAMPER创新法则的应用 1.1 SCAMPER应用之改进自行车设计 一家自行车制造商希望改进其自行车设计&#xff0c;以吸引更多的消费者并提高市场份额。他们决…

Python chardet库:字符编码检测

更多Python学习内容&#xff1a;ipengtao.com 在处理文本文件时&#xff0c;字符编码问题常常会导致乱码和错误。Python的chardet库是一个功能强大的字符编码检测工具&#xff0c;能够帮助开发者自动检测文本的编码方式&#xff0c;从而正确地读取和处理文本文件。本文将详细介…

⌈ 传知代码 ⌋ 【CLIP】文本也能和图像配对

&#x1f49b;前情提要&#x1f49b; 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间&#xff0c;对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…

LLM Algorithms(1): Flash Attention

目录 Background Flash Attention Flash Attention Algorithm 参考 NIPS-2022: Flash Attention: Fast and Memory-Efficient Exact Attention with IO-Awareness idea&#xff1a;减少资源消耗&#xff0c;提升或保持模型性能。普通attention的空间复杂度是 --》降低到F…

【PR2019】怎样批量添加转场效果及修改默认持续时间

一&#xff0c;设置“交叉溶解”效果到所有素材 选择效果&#xff0c;右击“将所选过渡设置为默认过渡”&#xff1a; 框选所有素材&#xff0c;“Ctrl D”&#xff1a; 每个素材中间有有了交叉溶解的效果&#xff1a; 二&#xff0c;修改效果属性 2.1&#xff0c;单个修…

1.nginx介绍

介绍 是一个高性能的http和反向代理服务器。 特点 占用内存少&#xff0c;并发能力强。 nginx专为性能优化而开发&#xff0c;性能是其最重要的考量&#xff0c;实现上非常注重效率&#xff0c;能经受高负载的考验&#xff0c;有报告表明能支持高达50,000个并发连接数。 基…

拐点已至:企业如何借助AI重塑增长?

2024年的激进增长与AI数智化创新并行&#xff0c;传统策略的功效已经减弱。在这篇文章中&#xff0c;我们将展望并深度探索2024年的6大创新增长策略&#xff0c;包括AI驱动的实验&#xff0c;产品再造&#xff0c;超个性化&#xff0c;自动化运营&#xff0c;短视频和KOL营销等…

力扣hot100: 48. 旋转图像

LeetCode&#xff1a;48. 旋转图像 受到力扣hot100&#xff1a;54. 螺旋矩阵的启发&#xff0c;我们可以对旋转图像按层旋转&#xff0c;我们只需要记录四个顶点&#xff0c;并且本题是一个方阵&#xff0c;四个顶点就能完成图像的旋转操作。 1、逐层旋转 注意到&#xff0…

Java核心: JarIndex的使用

在讲解Java类加载器的时候&#xff0c;我们发现URLClassLoader加载类或资源时通过访问ClassPath下的每一个路径&#xff0c;来确定类是否存在的&#xff0c;假设我们执行的命令是这样的 java -classpath D:\DiveInSpring\target\classes;C:\lib\spring-expression.jar;C:\lib\…

扩展学习|风险管理的文献综述汇总(持续更新向)

一、风险管理发展历程和趋势综述&#xff08;2007年发表&#xff09; 文献来源&#xff1a;[1]严复海,党星,颜文虎.风险管理发展历程和趋势综述[J].管理现代化, 2007(2):4.DOI:CNKI:SUN:GLXX.0.2007-02-009. 简介&#xff1a;该文以风险管理发展历程中的大事件为线索, 对风险管…

第1回 最开始的两行代码

当你按下开机键的那一刻,在主板上提前写死的固件程序BIOS会将硬盘启动区中的512(B)的数据,原封不动地复制到内存中的0x7c00这个位置,并跳转到那个位置: 下面我们针对每一步做详细介绍. 开机后初始化指向BIOS CPU中有一个PC寄存器,里面存储这将要执行的指令在内存中的地…

挑战绝对不可能:再证有长度不同的射线

黄小宁 一空间坐标系中有公共汽车A&#xff0c;A中各座位到司机处的距离h是随着座位的不同而不同的变数&#xff0c;例如5号座位到司机处的距离是h3&#xff0c;…h5&#xff0c;…。A移动了一段距离变为汽车B≌A&#xff0c;B中5号座位到司机处的距离h’h3&#xff0c;…h’h5…

C语言详解文件操作

目录 什么是文件&#xff1f; 为什么使用文件&#xff1f; 程序文件和数据文件、文本文件和二进制文件 1.程序文件和数据文件 1.1程序文件 1.2数据文件 2.文本文件和二进制文件 文件的打开和关闭&#xff08;流、标准流、文件指针和文件的打开与关闭&#xff09; 1.流和标…