导出谷歌gemma模型为ONNX

参考代码如下(从GitHub - luchangli03/export_llama_to_onnx: export llama to onnx修改而来,后面会合入进去)

模型权重链接参考:

https://huggingface.co/google/gemma-2b-it

可以对modeling_gemma.py进行一些修改(transformers升级为最新版本内置该模型代码),从而提升导出的onnx性能:

1,GemmaForCausalLM中原始的logits计算为:

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

修改为:

        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)

这样使得降低prefill阶段lm_head的计算量。

2,模型使用了GemmaSdpaAttention,导出的onnx模型从一个很大的张量中索引向量仅仅用作attention mask:

causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
    causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

这里即增加了存储又增加了计算。实际上可以直接把扩展后的attention mask作为onnx输入传入进来,从而完全消除这个存储和计算。

不知为何很多模型(例如千问等)都输入一个[1, seq_len]的向量,然后内部扩展为一个[1,1, seq_len, sumN]的mask,这些操作都可以直接替换为模型直接采用[1,1, seq_len, sumN]的mask输入。

这里对modeling_gemma.py修改方法为:

class GemmaModel(GemmaPreTrainedModel):
    def forward(
        # causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
        causal_mask = attention_mask

class GemmaSdpaAttention(GemmaAttention):
    def forward(
        # if attention_mask is not None and cache_position is not None:
        #     causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

模型导出代码(进行了上述修改,如果不想修改的话,修改下这里面的atten mask的shape,dtype即可):

import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class LLMForCausalLMWrapper(nn.Module):
    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,
    ):
        """
        Note: you can modify modeling_gemma.py to make the converted model more efficient:
        hidden_states = outputs[0]
        hidden_states = hidden_states[:,-1:,:]
        logits = self.lm_head(hidden_states)
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=None,
            use_cache=True,
        )

        logits = outputs.logits
        kv_caches_out = []
        for past_kv in outputs.past_key_values:
            kv_caches_out.extend(past_kv)

        topk_outputs = []
        if self.args.add_topk_warper > 0:
            logging.warning("add topk to glm model")
            if self.args.topk < 0:
                raise ValueError("topk {} is invalid")
            topk_outputs = torch.topk(logits, k=self.args.topk, dim=-1)

        return logits, *kv_caches_out, *topk_outputs


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)

    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    layer_num = len(model.model.layers)

    hidden_size = config.hidden_size
    head_num = config.num_attention_heads
    head_dim = config.head_dim

    batch = 1
    N = 1
    sumN = 32
    lastSum = sumN - N

    input_ids_shape = [batch, N]
    input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)
    # Note: orig atten_mask shape is [1, sumN]
    attention_mask = torch.randn([batch, 1, N, sumN], dtype=dtype).to(args.device)
    position_ids = torch.ones([batch, N], dtype=torch.int64).to(args.device)

    in_names = ["input_ids", "attention_mask", "position_ids"]

    dynamic_axes = {
        'input_ids': {1: 'N', },
        'attention_mask': {2: 'N', 3: 'sumN'},
        "position_ids": {1: 'N', },
    }
    if args.dyn_batch:
        dynamic_axes['input_ids'][0] = "batch"
        dynamic_axes['attention_mask'][0] = "batch"
        dynamic_axes['position_ids'][0] = "batch"

    kv_caches_in = []
    out_names = ["lm_logits"]

    kv_cache_in_shape = [1, 1, lastSum, head_dim]
    kv_cache_dyn_axes = {2: "sumN-N"}

    if args.dyn_batch:
        kv_cache_dyn_axes[0] = "batch"

    past_key_values = []

    for i in range(layer_num):
        past_key_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)
        past_value_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)

        kv_caches_in.extend([past_key_in, past_value_in])
        in_names.extend([f"past_key_in{i}", f"past_value_in{i}"])
        out_names.extend([f"past_key{i}", f"past_value{i}"])

        dynamic_axes[f"past_key_in{i}"] = kv_cache_dyn_axes
        dynamic_axes[f"past_value_in{i}"] = kv_cache_dyn_axes

        past_key_values.append((past_key_in, past_value_in))

    input_datas = (input_ids, attention_mask, position_ids, past_key_values)

    torch.onnx.export(
        llama_model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_llama(args):
    device = args.device
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()

    # model.model.layers = model.model.layers[:1]  # only export one layer for debug

    print(f"finish load model from {args.model_path}")
    config = model.config
    print("config:", config)

    print(f"begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='export llm',
    )
    parser.add_argument('-m', '--model_path', required=True, type=str)
    parser.add_argument('-o', '--out_dir', required=False, type=str, default="")
    parser.add_argument('--opset', required=False, type=int, default=15)
    parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument('-p', '--dtype', required=False, type=str,
                        choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument('--add_topk_warper', required=False, type=int, default=0)
    parser.add_argument('--topk', required=False, type=int, default=4)
    parser.add_argument('--dyn_batch', action='store_true')

    args = parser.parse_args()
    export_llama(args)

导出的onnx文件onnxsim:

GitHub - luchangli03/onnxsim_large_model: simplify >2GB large onnx model

导出的onnx模型推理示例(依赖文件在GitHub - luchangli03/export_llama_to_onnx: export llama to onnx)

import numpy as np
from onnx_rt_utils import OnnxRuntimeModel, get_random_data
from sample_utils import sample_topk
from transformers import AutoTokenizer


def prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum):
    """
    only used at the first time
    in round 0, actually the lastSum is 0, thus past_key_in, past_value_in are empty tensor
    """
    for i in range(layer_num):
        past_key_in = get_random_data([1, 1, lastSum, 256], "float16")
        past_value_in = get_random_data([1, 1, lastSum, 256], "float16")

        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"
        glm_model_inputs[past_key_in_name] = past_key_in
        glm_model_inputs[past_value_in_name] = past_value_in
    return glm_model_inputs


def prepare_kv_cache_from_outputs(glm_model_inputs, decoder_outputs, layer_num):
    offset = 1
    for i in range(layer_num):
        past_key_in_name = f"past_key_in{i}"
        past_value_in_name = f"past_value_in{i}"

        glm_model_inputs[past_key_in_name] = decoder_outputs[offset + i * 2]
        glm_model_inputs[past_value_in_name] = decoder_outputs[offset + i * 2 + 1]
    return glm_model_inputs


def get_atten_mask(N,  sumN,  padded_len):
    attention_mask = np.zeros(shape=[N * padded_len], dtype="float16")

    pad_num = padded_len - sumN
    if (N == sumN):
        for i in range(N):
            mask_num = N - 1 - i + pad_num
            start = padded_len - mask_num
            for j in range(start, padded_len):
                attention_mask[i * padded_len + j] = -65504
    else:
        if (N != 1):
            raise ValueError("N is not 1")
        lastSum = sumN - N
        for i in range(pad_num):
            attention_mask[lastSum + i] = -65504

    attention_mask = attention_mask.reshape([N, padded_len])
    return attention_mask


# all decoder layer num
layer_num = 18
eos_token_id = 2

pt_model_path = r"E:\test_models\llama\gemma-2b-it"
onnx_model_path = "llm_onnx.onnx"

prompt = "Write me a poem about Machine Learning."
tokenizer = AutoTokenizer.from_pretrained(pt_model_path, trust_remote_code=True)
input_ids = tokenizer(prompt)['input_ids']

print(input_ids)

input_ids = np.array(input_ids).reshape([1, -1]).astype("int64")

N = input_ids.shape[1]
sumN = N
lastSum = sumN - N
print("N:", N, sumN, lastSum)

position_ids = np.arange(sumN).reshape([1, -1]).astype("int64")

input_ids = input_ids.astype("int64")
position_ids = position_ids.astype("int64")

glm_model = OnnxRuntimeModel(onnx_model_path)

max_seq = 32

glm_model_inputs = {}

gen_tokens = []

for i in range(max_seq):
    print("input_ids:", input_ids)
    print("position_ids:", position_ids)

    attention_mask = get_atten_mask(N, sumN, padded_len=sumN).astype("float16")
    print("attention_mask:", attention_mask)
    attention_mask = attention_mask.reshape([1, 1, N, sumN])

    glm_model_inputs["input_ids"] = input_ids
    glm_model_inputs["attention_mask"] = attention_mask
    glm_model_inputs["position_ids"] = position_ids

    if i == 0:
        glm_model_inputs = prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum)

    glm_model_outputs = glm_model(**glm_model_inputs)
    lm_logits = glm_model_outputs[0]
    print("lm_logits:", lm_logits)

    next_token = sample_topk(lm_logits, topk=1)
    gen_tokens.append(next_token)
    print("next_token:", next_token)

    if next_token == eos_token_id:
        break

    input_ids = np.array([next_token]).astype("int64").reshape([-1, 1])
    position_ids = np.array([sumN]).astype("int64").reshape([-1, 1])

    N = 1
    sumN += 1
    prepare_kv_cache_from_outputs(glm_model_inputs, glm_model_outputs, layer_num)

gen_text = tokenizer.decode(gen_tokens)
print("Q:", prompt)
print("A:", gen_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/441231.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

docker搭建dashdot

Dashdot 是一个指标收集工具&#xff0c;用于报告 Kubernetes 集群中的资源使用情况。假设你想要使用 Docker 来搭建 Dashdot&#xff0c;你需要制作或获取一个 Dashdot 的 Docker 镜像&#xff0c;然后可以通过 Docker CLI 命令或者使用 Docker Compose 来配置和运行这个容器。…

TinTin DESTINATION MOON|开发者不容错过的 Web3 线下活动来啦!

还记得去年 9 月 17 日的上海吗&#xff1f;「DESTINATION MOON: Web3 Dev Summit Shanghai 2023」迎来了数百名 Web3 行业爱好者的关注和参与。4 场主题演讲、3 场圆桌讨论&#xff0c;近 20 名创新者、开发者、投资人和研究员围绕公链生态、Layer2 竞争、DID、ZKP、安全等热点…

人工智能聊天机器人完整指南 - 推荐10家国外聊天机器人公司

人工智能&#xff08;AI&#xff09;聊天机器人革命正在向我们袭来。由对话式AI驱动的AI聊天机器正在改变企业世界&#xff0c;为公司提供更高效的方式与客户和员工互动。本综合指南将介绍AI聊天机器人&#xff0c;解释其主要功能和优势&#xff0c;并探讨它们如何改变您的业务…

3.8 动态规划 背包问题

一.01背包 46. 携带研究材料&#xff08;第六期模拟笔试&#xff09; (kamacoder.com) 代码随想录 (programmercarl.com) 携带研究材料: 时间限制&#xff1a;5.000S 空间限制&#xff1a;128MB 题目描述: 小明是一位科学家&#xff0c;他需要参加一场重要的国际科学大会…

使用docker安装运行rabbitmq---阿里云服务器

目录 0、阿里云没开端口的得要去安全组规则去添加&#xff1a; 1、下载RabbitMQ镜像&#xff1a; 2、查看镜像是否下载成功&#xff0c;得到docker镜像id&#xff1a; 3、运行RabbitMQ: 4、查看RabbbitMQ容器是否启动成功&#xff1a; 5、启动RabbitMQ中的插件管理 6、访…

RabbitMQ的web控制端介绍

2.1 web管理界面介绍 connections&#xff1a;无论生产者还是消费者&#xff0c;都需要与RabbitMQ建立连接后才可以完成消息的生产和消费&#xff0c;在这里可以查看连接情况channels&#xff1a;通道&#xff0c;建立连接后&#xff0c;会形成通道&#xff0c;消息的投递、获取…

Z Potentials | 星爵,他的征途不止向量数据库

纵观过去几十年的科技发展史&#xff0c;每一代新的技术架构的出现往往都伴随着新的数据范式的出现&#xff0c;也催生了多家百亿到千亿美金数据平台的诞生。如果说 2023 年科技领域的关键词是 LLM&#xff0c;那么数据库领域的关键词一定非向量数据库莫属。向量数据库是一种专…

C++面向对象程序设计-北京大学-郭炜【课程笔记(五)】

C面向对象程序设计-北京大学-郭炜【课程笔记&#xff08;五&#xff09;】 1、常量对象、常量成员函数1.1、常量对象1.2、常量成员函数1.3、常引用 2、友元&#xff08;friends&#xff09;2.1、友元函数2.2、友元类 3、运算符重载的基本概念3.1、运算符重载 4、赋值运算符的重…

红黑树的学习

红黑树 红黑树出自一种平衡的二叉查找树&#xff0c;是计算机科学中中用到的一种数据结构 1972年出现&#xff0c;当时被称之为平衡二叉B树。后来&#xff0c;1978年被修改为如今的红黑树 他是一种特殊的二叉查找树&#xff0c;红黑树的每一个节点上都有存储表示节点的颜色 …

HarmonyOS NEXT应用开发案例——自定义TabBar

介绍 本示例主要介绍了TabBar中间页面如何实现有一圈圆弧外轮廓以及TabBar页签被点击之后会改变图标显示&#xff0c;并有一小段动画效果。 效果图预览 使用说明&#xff1a; 依次点击tabBar页面&#xff0c;除了社区图标之外&#xff0c;其它图标往上移动一小段距离。 实现…

软件测试【测试用例设计】面试题详解

前言 今天笔者想和大家来聊聊测试用例&#xff0c;这篇文章主要是想要写给测试小伙伴们的&#xff0c;因为我发现还是有很多小伙伴在遇到写测试用例的时候无从下手&#xff0c;我就想和大家简单的聊聊&#xff0c;这篇文章主要是针对功能测试的。 一、微信功能测试 1.点击点…

STL之map容器代码详解

基础概念 简介&#xff1a; map中所有元素都是pair。pair中第一个元素为key&#xff08;键值&#xff09;&#xff0c;起到索引作用&#xff0c;第二个元素为value&#xff08;实值&#xff09;。所有元素都会根据元素的键值自动排序。 本质&#xff1a; map/multimap属于关…

【Android取证篇】渗透测试工具apk2url快速提取APK内的IP和URL地址

【Android取证篇】渗透测试工具apk2url快速提取APK内的IP和URL地址 通过渗透测试工具apk2url快速检索APK开发过程中没有删掉的URL地址&#xff0c;来发现一些搜索引擎、子域名查找不到的资源&#xff0c;从而进一步收集信息查找后台等—【蘇小沐】 1、实验环境 系统环境Wind…

Spring基础

spring讲义 spring官网 下文中所有项目都是通过 maven 构建的quickstart项目 csdn比较好的博客 1.什么是Spring框架 它是一个容器&#xff0c;帮助解决企业开发的难度&#xff0c;减轻对项目模块之间的管理&#xff0c;类和类之间的管理&#xff0c;帮助开发人员创建对象&a…

Linux——进程间通信

目录 进程间通信介绍 什么是进程间通信 为什么要进行进程间通信 怎么做到进程间通信 管道 管道的原理 匿名管道 pipe函数 简单线程池 管道读写的规则 命名管道 创建一个管道文件 在代码中创建管道 在代码中删除管道 命名管道实现serve与client通信 system V共享…

数组连续和 - 华为OD统一考试(C卷)

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 给定一个含有N个正整数的数组&#xff0c;求出有多少连续区间&#xff08;包括单个正整数&#xff09;&#xff0c;它们的和大于等于 x。 输入描述 第一行为两个…

掌握Python操作Word:从基础到高级全覆盖

掌握Python操作Word&#xff1a;从基础到高级全覆盖 引言Python操作Word的基础文档的创建与打开文档的基本操作 创建和打开Word文档创建新的Word文档打开现有文档读取文档内容修改现有文档 编辑文档内容添加和编辑文本设置格式插入标题 处理文档结构操作段落列表的处理表格的操…

董宇辉所有商标已转到与辉同行名下!

近日董宇辉此前由东方优选申请的所有商标已转到与辉同行主体名下&#xff0c;普推知产老杨经检索发现&#xff0c;这些商标都是2022年6月由东方优选提交申请&#xff0c;在2023年12月28时提交商标转让&#xff0c;最近转让成功&#xff0c;转让周期是2个半月左右。 转让的商标除…

Windows11下载、安装和配置JDK(包含多个版本的JDK配置)

下载JDK17 下载地址 JDK_DOWNLOAD选择JDK17版本 安装JDK17 双击打开安装包 -> 选择下一步 -> 选择安装路径&#xff08;注意不要安装在带有中文的路径下&#xff09;->修改完路径后点击下一步->安装完成。 检验安装是否成功&#xff0c;打开cmd&#xff0c;输…

C#中实现接口的一些小知识(C#用abstract或virtual来实现接口成员)

文章目录 不可用的修饰可用的修饰非抽象类实现接口抽象类实现接口抽象类与接口方法同名时一同实现 不可用的修饰 在C#中实现接口时&#xff0c;我们不能直接使用static或const来实现接口成员&#xff0c;因为接口中的成员默认都是实例成员&#xff0c;并且它们表示一种契约&am…