Accelerating ChatGLM Inference with torch.compile in the gpt-fast Code Style


Contents

1. Migrating and refactoring the ChatGLM model code

2. Refactoring the inference code

3. Performance comparison

References


After torch 2.0 was released, a single line of code was supposed to speed up model training and inference, but when I first tried it the effect was not obvious. Then the gpt-fast project came out and showed that real inference speedups are achievable, so apparently I had just been using it the wrong way. Recently I followed gpt-fast and implemented inference acceleration for the ChatGLM model. The core idea is to have torch.compile build a computation graph for the inference path. The focus of this post is the migration of the model code and the inference logic, plus a comparison of the resulting speedups. This approach certainly still falls short of vLLM and TensorRT-LLM, but that is not the point here; when I find time I will write a separate post comparing vLLM and TensorRT-LLM as well.

1. Migrating and refactoring the ChatGLM model code

This part is genuinely not easy: you need to be very familiar with the model structure and its inputs/outputs, and also with the gpt-fast migration principles, to pull it off quickly. The core principles are: no tensor slicing on dynamic shapes; the kv cache must be allocated at a fixed length and continuously filled/updated in place during computation; and the cache has to live outside the model and be passed in as an argument, otherwise compilation brings no speedup. Another point to watch is the attention computation: torch's scaled_dot_product_attention makes it possible to compute attention over a fixed, maximum-length cache and get exactly the same values as the usual scheme where the kv length grows step by step (I verified they match). That equivalence is the precondition for rewriting the tensor-slicing parts of the attention code. Finally, mind the kv cache's dimension layout, and note that the full_attention_mask fed to the model differs between decoding stages (the first forward pass versus later passes once the kv cache is populated).
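Here is a quick numeric check of that equivalence (a toy sketch; the shapes are chosen only for illustration): attention computed against a sliced, growing kv cache matches attention computed against the full fixed-length cache once the unfilled slots are masked out.

import torch
import torch.nn.functional as F

B, H, S_max, D = 1, 2, 8, 16
q = torch.randn(B, H, 1, D)           # one decoding-step query
k = torch.randn(B, H, S_max, D)       # fixed-length key cache
v = torch.randn(B, H, S_max, D)
valid = 5                             # only the first 5 slots hold real entries

# dynamic variant: slice the cache down to the valid prefix
out_dynamic = F.scaled_dot_product_attention(q, k[:, :, :valid], v[:, :, :valid])

# fixed-length variant: keep the full cache, mask out the unfilled slots
mask = torch.zeros(B, 1, 1, S_max, dtype=torch.bool)
mask[..., :valid] = True              # True = this position may be attended
out_fixed = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(torch.allclose(out_dynamic, out_fixed, atol=1e-6))  # True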

Overall structure

class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = 128
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents

Note the new forward inputs: input_pos, the cache positions of the tokens being decoded, and kv_caches. The basic building blocks are unchanged; I only trimmed some preprocessing logic and branches so that torch.compile() can construct a complete computation graph.
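To make the calling convention concrete, here is a minimal sketch of a prefill-stage forward call. This is a sketch only: it assumes config and model are built as in the __main__ block at the end of the listing, that the model has been cast to bfloat16 to match the cache, and it uses ChatGLM2-6B's 2 kv groups and 128 head dim; the per-layer cache shape follows the generate code in section 2.

import torch

B, S, S_max = 1, 16, 1000
kv_caches = [(torch.zeros(S_max, B, 2, 128, dtype=torch.bfloat16),
              torch.zeros(S_max, B, 2, 128, dtype=torch.bfloat16))
             for _ in range(config.num_layers)]
input_ids = torch.randint(0, config.vocab_size, (B, S))
position_ids = torch.arange(S).unsqueeze(0)   # [B, S]
input_pos = torch.arange(S).unsqueeze(0)      # cache slots to fill, [B, S]
# attention_mask=None takes the is_causal=True path used for the first forward
logits, kv_caches = model(input_ids, position_ids=position_ids,
                          input_pos=input_pos, kv_caches=kv_caches)
print(logits.shape)                           # [B, S, vocab_size]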

The KV cache module

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S]; k_val, v_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        # [S, B, H, D] -> [H, B, S, D] to match the cache layout
        k_val = k_val.transpose(0, 2).contiguous()
        v_val = v_val.transpose(0, 2).contiguous()
        # write the new entries into their slots along the sequence axis
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        # back to [S, B, H, D] for the attention computation
        k_out = k_out.transpose(0, 2).contiguous()
        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out

The tensor shapes of the variables are annotated in the comments; the module serves as the kv-cache container and provides a single update method.
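A small shape check of the update method (a sketch; H=2 kv groups and D=128 match the ChatGLM2 config, the other values are arbitrary):

import torch

cache = KVCache(max_batch_size=1, max_seq_length=32)
S, B, H, D = 4, 1, 2, 128
k_val = torch.randn(S, B, H, D, dtype=torch.bfloat16)
v_val = torch.randn(S, B, H, D, dtype=torch.bfloat16)
input_pos = torch.arange(S)                # fill cache slots 0..3
k_out, v_out = cache.update(input_pos, k_val, v_val)
print(k_out.shape)                         # torch.Size([32, 1, 2, 128])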

I won't go through the other modules one by one; just note how the kv cache is updated inside SelfAttention.

The complete model code is as follows:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from math import gcd
from functools import reduce
import math


def find_multiple(n: int, *args: int) -> int:
    # round n up to the nearest multiple of lcm(args)
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)


class CoreAttention(torch.nn.Module):
    def __init__(self, config, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32

        self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        coeff = self.layer_number
        self.norm_factor *= coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None:
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             is_causal=True)
        else:
            # the incoming mask marks visible positions as False; SDPA expects True
            attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask)

        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S]; k_val, v_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache
        v_out = self.v_cache
        # [S, B, H, D] -> [H, B, S, D] to match the cache layout
        k_val = k_val.transpose(0, 2).contiguous()
        v_val = v_val.transpose(0, 2).contiguous()
        # write the new entries into their slots along the sequence axis
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        # back to [S, B, H, D] for the attention computation
        k_out = k_out.transpose(0, 2).contiguous()
        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = 128
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.config = config
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # 32
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         # device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device,
                               # **_config_to_kwargs(config)
                               )

    def forward(
            self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )

        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # update the kv cache in place at the current positions
        cache_k, cache_v = kv_cache
        cache_k[input_pos] = key_layer
        cache_v[input_pos] = value_layer
        key_layer = cache_k.clone()
        value_layer = cache_v.clone()
        kv_cache = (key_layer, value_layer)

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask=attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


class TransformerBlock(nn.Module):
    def __init__(self, config, layer_number, device) -> None:
        super().__init__()
        self.hidden_dropout = config.hidden_dropout
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                dtype=config.torch_dtype)
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None):
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            rotary_pos_emb,
            input_pos,
            attention_mask=attention_mask,
            kv_cache=kv_cache
        )
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_input
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output
        return output, kv_cache



if __name__ == '__main__':
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = "1"
    from transformers import AutoConfig

    model_path = "./chatglm2-6b-merge"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = TransformerGLM(config, device=None)
    for name, _ in model.named_parameters():
        print(name)

2. Refactoring the inference code

The inference side means rewriting the generate method of the transformers model. A single generation run splits into the first decoding forward pass (prefill) and all remaining decoding forward passes. The two implementations are below; only the greedy search strategy is implemented:

@torch.no_grad()
def first_decode_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids=input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=False,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches


@torch.no_grad()
def decode_one_token_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=True,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches

These two functions produce the next token and the updated kv_caches at each decoding step. One important caveat: do not wrap the two methods in a class and then run torch.compile on that. The model still produces correct outputs, but inference gets no faster; in other words torch.compile silently fails to take effect. Keep them as module-level functions, as illustrated below.
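A minimal illustration of the pattern that worked for me versus the one that did not (toy function, illustrative only):

import torch

def step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

# works: compile the module-level function and rebind the name,
# exactly as the driver code below does with the two decode functions
step = torch.compile(step, mode="reduce-overhead", fullgraph=True)

# what did NOT help in my tests: moving such functions into a class and
# compiling the methods; outputs stayed correct, but no speedup appeared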

The overall generate logic covers the stop token, the model's initial inputs, kv_caches initialization, how the attention_mask and position_ids inputs change across steps, and the padding added for batched inference.

def generate_own_batch(model, inputs, sampling_kwargs, eos_token,
                       max_seq_length, max_batch_size):
    device = inputs['input_ids'].device
    cache_shape = (max_seq_length, max_batch_size, 2, 128)
    dtype = torch.bfloat16
    kv_caches = [(torch.zeros(cache_shape, dtype=dtype).to(device), torch.zeros(cache_shape, dtype=dtype).to(device))
                 for _ in range(model.config.num_layers)]

    input_ids = inputs['input_ids']

    ori_input_ids = input_ids.clone()
    position_ids = inputs['position_ids']

    input_pos = []
    for _ in range(max_batch_size):
        pos = list(range(0, input_ids.shape[1]))
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=input_ids.device)

    # input_pos = torch.arange(0, input_ids.shape[1], device=input_ids.device)
    next_token, kv_caches = first_decode_batch(model, input_ids, position_ids, input_pos, None, kv_caches)

    full_attention_mask = torch.ones(max_batch_size, 1, 1, max_seq_length).to(device).bool()
    full_attention_mask[:, :, :, input_pos] = False

    # mark the padding positions as True (masked out)
    for i in range(full_attention_mask.shape[0]):
        for j in range(input_ids.shape[1]):
            if input_ids[i, j] == 0:
                full_attention_mask[i, :, :, j] = True

    input_ids = torch.cat((input_ids, next_token.clone()), dim=1)
    num_new_tokens = sampling_kwargs["max_length"]
    T = input_ids.size()[1]

    position_ids = position_ids[:,-1:]

    input_pos = []
    for _ in range(max_batch_size):
        pos = [T]
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=next_token.device, dtype=torch.long)

    # position_ids = torch.tensor([[T - 1]], device=next_token.device, dtype=torch.long)
    # input_pos = torch.tensor([T], device=input_ids.device, dtype=torch.long)

    for i in range(num_new_tokens):
        input_pos += 1
        # Actually better for Inductor to codegen attention here
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            full_attention_mask[:, :, :, input_pos] = False
            next_token, kv_caches = decode_one_token_batch(model, next_token, position_ids, input_pos,
                                                     full_attention_mask, kv_caches)

            input_ids = torch.cat((input_ids, next_token.clone()), dim=1)

            if (input_ids == eos_token).sum(dim=1).all():
                break

            position_ids += 1
    return input_ids, ori_input_ids

Core inference driver code

    model = TransformerGLM(config=config, device=None)
    checkpoint_dir = Path(model_path)
    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
    converted_state_dict = model.state_dict()

    gen_kwargs = {"max_length": 200, "num_beams": 1,
                  "do_sample": False, "top_p": 0.8,
                  "temperature": 0.95
                  }
    device = "cuda:0"
    model_path = "./chatglm2-6b-merge"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    eos_token = tokenizer.eos_token_id
......
# compile the two decode-stage functions
    global first_decode_batch, decode_one_token_batch
    decode_one_token_batch = torch.compile(decode_one_token_batch, mode="reduce-overhead", fullgraph=True)
    first_decode_batch = torch.compile(first_decode_batch, dynamic=True, fullgraph=True)

......
generate_own_batch(model, inputs, gen_kwargs, eos_token, max_seq_length, max_batch_size)

The key point is simply to wrap the two decoding-stage functions with torch.compile; the actual decoding loop then runs through the compiled versions and gets the speedup.
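One practical note: torch.compile pays its compilation cost on the first calls, so any timing should be taken after warm-up (a sketch, assuming the variables from the driver code above):

import torch

# warm up: the first generations trigger compilation and run slowly
for _ in range(3):
    generate_own_batch(model, inputs, gen_kwargs, eos_token,
                       max_seq_length, max_batch_size)
torch.cuda.synchronize()
# subsequent runs reflect the compiled steady state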

3. Performance comparison

Below is a comparison of inference speed and output quality for the ChatGLM2-6B model under ori (the original implementation), compile, and compile+int8, with bs=1 and max_seq_length=1000. The input prompts were:

[
    "你好",
    "你是谁呀?",
    "你能做什么呀?",
    "你真厉害",
    "真棒呀",
    "再见了",
    "给我推荐一部电影",
    "你知道明天天气怎么样吗?"
]

Inference results with the original transformers implementation:

Results with compile:

Results with compile+int8:

With the same model and data at bs=1, the original implementation runs at 31.7 tokens/s, compile reaches 68.1 tokens/s, and compile+int8 reaches 110.9 tokens/s. The acceleration is clearly significant.
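For reference, a sketch of how a tokens/s figure like these can be measured (my actual measurement script is not shown; this only illustrates the idea):

import time
import torch

torch.cuda.synchronize()
start = time.perf_counter()
output_ids, prompt_ids = generate_own_batch(model, inputs, gen_kwargs, eos_token,
                                            max_seq_length, max_batch_size)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
new_tokens = output_ids.shape[1] - prompt_ids.shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")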

From experiments on our business workload I can also give a conclusion, though the data is not shown: each request there generates mostly fewer than 20 tokens, and the results were as follows:

That's all for this share. One remaining problem with the migrated model and inference code on our company's servers: our web service runs as multiple async worker processes, and in that deployment the gpt-fast-style integration shows int8 not taking effect, while the bs=1 speedup is much less pronounced than offline. I never figured out the exact cause; possibly other processes competing for server resources weaken or defeat the torch.compile acceleration.

References

gpt-fast in practice

modeling_chatglm

gpt-fast project source code
