从FasterTransformer源码解读开始了解大模型(2.1)代码通读03

从FasterTransformer源码解读开始了解大模型(2.2)代码解读03-forward函数

写在前面的话

本篇的内容继续解读forward函数,从650行开始进行解读

零、输出Context_embeddings和context_cum_log_probs的参数和逻辑

从653行开始,会从输入的请求tensors中读取一个配置,如果请求中配置了is_return_context_embeddings参数并设置为true时,则会在返回参数中增加一个context_embeddings的tensor,这个tensor中包含的数据是输入经过了ContextDecoder过程的所有的层之后的logits,并对其进行求和。可能有些类似于强化学习(RLHF)之类的场景会用到这里的输出,所以在这里做了一层准备。

跳转到1093行,可以看见,这里调用了invokeSumLengthDimension这个kennels,来将context_decoder_output_buf这块buf中的显存数据拷贝到输出tensor的context_embeddings中,可以简单看一下这个kernel的实现,在src/fastertransformer/kernels/gpt_kernels.cu中

template<typename T>
void invokeSumLengthDimension(float*       out_buf,
                              const T*     in_buf,
                              const size_t batch_size,
                              const size_t input_length,
                              const size_t hidden_dim,
                              cudaStream_t stream)
{
    dim3 gridSize(batch_size);
    dim3 blockSize(256);

    sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}


template<typename T>
__global__ void sum_length_dimension(
    float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
    const int bidx = blockIdx.x;

    for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
        float accum = 0.0f;
        for (int step = 0; step < input_length; step++) {
            accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
        }
        out_buf[bidx * hidden_dim + hidx] = accum;
    }
}

从kernel中可以看出,这个kernel任务划分为按照batch_size维度进行了grid task分配,并为每个任务划分了256个线程。在kernel内部则是将每个batch内所有输入的logits按照输入length维度进行了累加,并拷贝到输出buf中。

类似的一个设置和处理环节,在783行还处理了is_return_context_cum_log_probs,这里会将contextDecoder完成之后的输出进行cum log计算。其中主要的处理逻辑函数是403行的computeContextCumLogProbs函数,进入这个函数后,在425到502行的处理逻辑是,首先将context_decoder的输出进行layernorm,然后使用cublas矩阵乘,将词表维度的logits计算出来。在此之后,在504行,则会使用invokeLogProbFromLogits这个kernel.

template<typename T>
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first)
{
    // A batched version of log prob computation.
    //
    // cum_log_probs: [batch_size]
    // logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]
    // input_ids: [max_input_length, batch_size] or [max_input_length, batch_size]
    // input_lengths: [batch_size]
    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.

    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    // block_size should be multiple of 32 to use warpReduceMax.
    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
    assert(block_size % 32 == 0);
    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
    assert(vocab_size <= vocab_size_padded);

    float* log_probs = reinterpret_cast<float*>(workspace);
    int    gx        = batch_first ? batch_size : max_input_length - 1;
    int    gy        = batch_first ? max_input_length - 1 : batch_size;
    dim3   grid(gx, gy);
    log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
                                                         logits,
                                                         input_ids,
                                                         input_lengths,
                                                         max_input_length,
                                                         batch_size,
                                                         vocab_size,
                                                         vocab_size_padded,
                                                         batch_first);
    accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}

__global__ void accumulate_log_probs(float*       cum_log_probs,
                                     const float* log_probs,
                                     const int*   lengths,
                                     const size_t max_input_length,
                                     const size_t batch_size,
                                     const bool   batch_first)
{
    // Accumulate the log probability along with the sequence dimension.
    //   cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]
    //
    // cum_log_probs: [batch_size], cumulative log probability
    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
    //   log probability of each token
    // lengths: [batch_size], sequence lengths
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.

    int bidx = blockIdx.x;   // batch dim
    int tidx = threadIdx.x;  // step dim

    if (bidx < batch_size) {
        int length = lengths[bidx];
        // reposition logits to data for the current batch.
        log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;
        int   stride      = batch_first ? 1 : batch_size;  // stride along with seq dim.
        float local_accum = 0.0f;
        for (int step = tidx; step < length - 1; step += blockDim.x) {
            local_accum += static_cast<float>(log_probs[step * stride]);
        }
        float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);
        if (tidx == 0) {
            cum_log_probs[bidx] = accum;
        }
    }
}

在任务划分阶段,grid的x维度是batch维度,而y维度则是输入长度,进入kernel后,则是按照batch维度对任务的log_probs进行了ReduceSum,并写入到cum_log_probs中。

一、跳过prompt和prefix阶段

回到653行,接下来是ft中处理其自身特殊的prompt和prefix的逻辑,但我们的源码解读可以先跳过这一段。之所以要跳过这一段是在当前主要的大模型处理逻辑中,并不会需要在推理引擎这一层过多地关注prompt和prefix的逻辑。在对话的大模型agent中,对于第n次对话的输入/输出,其真正要进行forward的输入,往往是由前n-1轮模型的输出和用户的输入共同组成的。从推理引擎的角度来看,prefill阶段在真正端到端时延的占比并不算高,所以专门设计一个处理prompt的逻辑(还需要常驻的显存空间)多少显得有些得不偿失了

二、正常的逻辑起始和输入展开

让我们直接来到865行,考虑以下一种场景来简化我们的源码解读的结构:beam_width=1,tp=1,pp=1,use_shared_contexts不启用,也就是说,没有beam search,在单机单卡上进行一次简单的,不共享contexts的推理服务的进行。

从866行开始,由于是计算起始,所以先对一些计算中的buff进行清空。在870行,由于不进行beam search,所以也不需要对beam search用到的buf进行清空。

在877到887行,如果词表大小不对齐的话,需要将词表计算的权重进行对齐拷贝。

在889到905行,不使用shared context的话,这一段的逻辑也不需要进行处理。在914行处理prompt的逻辑也可以跳过,直接进入955行的处理逻辑。在955行,由于我们的输入是带batch的,所以需要将batch进行展开(tiled),这里使用的是kernel invokeTileGptPromptInputs

void invokeTileGptPromptInputs(int*         tiled_input_ids,
                               int*         tiled_input_lengths,
                               int*         tiled_prompt_lengths,
                               const int*   input_ids,
                               const int*   input_lengths,
                               const int*   prefix_prompt_lengths,
                               const int    batch_size,
                               const int    beam_width,
                               const int    max_input_length,
                               cudaStream_t stream)
{
    dim3 grid(batch_size, beam_width);
    dim3 block(min(1024, max_input_length));
    if (prefix_prompt_lengths != nullptr) {
        tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                              tiled_input_lengths,
                                                              tiled_prompt_lengths,
                                                              input_ids,
                                                              input_lengths,
                                                              prefix_prompt_lengths,
                                                              max_input_length);
    }
    else {
        tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                               tiled_input_lengths,
                                                               tiled_prompt_lengths,
                                                               input_ids,
                                                               input_lengths,
                                                               prefix_prompt_lengths,
                                                               max_input_length);
    }
}

template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int*       tiled_input_ids,
                                    int*       tiled_input_lengths,
                                    int*       tiled_prompt_lengths,
                                    const int* input_ids,
                                    const int* input_lengths,
                                    const int* prefix_prompt_lengths,
                                    const int  max_input_length)
{
    if (threadIdx.x == 0) {
        tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
        if (PREFIX_PROMPT) {
            tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
        }
    }
    for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
        tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
            input_ids[blockIdx.x * max_input_length + index];
    }
}

在任务划分阶段,按照batch维度进行划分,每个任务起了至少1024个线程来进行拷贝。由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再处理

在kernel中,由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再按照batch维度进行处理,将一个batch*max_input_length的数据进行展开,并拷贝到对应的buf,这有利于我们进行接下来的后续embedding和MHA计算
在这里插入图片描述

下一回预告

下一回继续讲解forward函数中的处理逻辑,会简单讲解embedding, pre_layernorm,以及进入attention_layer之后的初步讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/781534.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

怎样让家长单独查到自己孩子的期末成绩?

期末考试的钟声已经敲响&#xff0c;随着最后一份试卷的收卷&#xff0c;学生们的紧张情绪渐渐平息。然而&#xff0c;对于老师们来说&#xff0c;这仅仅是另一个忙碌周期的开始。成绩的统计、分析、反馈&#xff0c;每一项工作都不容小觑。尤其是将成绩单一一私信给家长&#…

【Python】组合数据类型:序列,列表,元组,字典,集合

个人主页&#xff1a;【&#x1f60a;个人主页】 系列专栏&#xff1a;【❤️Python】 文章目录 前言组合数据类型序列类型序列常见的操作符列表列表操作len()append()insert()remove()index()sort()reverse()count() 元组三种序列类型的区别 集合类型四种操作符集合setfrozens…

分子AI预测赛Task4笔记(结束)

话不多说&#xff0c;直接上官方链接&#xff1a;‌​​​‍&#xfeff;​⁠​‌​‍​​&#xfeff;​‌​⁠‬​&#xfeff;‬​​‌​​​​‬‬​​​​‍⁠‍‌​&#xfeff;⁠Task3&#xff1a;进阶baseline详解 - 飞书云文档 (feishu.cn)Task4&#xff1a;持续尝试&…

RAID的实现

软RAID&#xff0c;在实际工作中使用较少&#xff0c;性能太次。 mdadm工具&#xff0c;主要在虚拟机上使用&#xff0c; 硬RAID 用一个单独的芯片&#xff0c;这个芯片的名字叫做RAID卡&#xff0c;数据在RAID中进行分散的时候&#xff0c;用的就是RAID卡。 模拟RAID-5工作…

【Transformer】transformer模型结构学习笔记

文章目录 1. transformer架构2. transformer子层解析3. transformer注意力机制4. transformer部分释疑 图1 transformer模型架构 图2 transformer主要模块简介 图3 encoder-decoder示意图N6 图4 encoder-decoder子层示意图 1. transformer架构 encoder-decoder框架是一种处理NL…

AI编程探索- iOS 实现类似苹果地图 App 中的半屏拉起效果

想要的效果 功能分析 想要实现这种效果&#xff0c;感觉有点复杂&#xff0c;于是就想搜一下相关资料看看&#xff0c;可问题是&#xff0c;我不知道如何描述这种效果&#x1f602;。 当我们遇到这种效果看着很熟悉&#xff0c;但是不知道如何描述它具体是什么的时候&#…

有一个日期(Date)类的对象和一个时间(Time)类的对象,均已指定了内容,要求一次输出其中的日期和时间

可以使用友元成员函数。在本例中除了介绍有关友元成员函数的简单应用外&#xff0c;还将用到类的提前引用声明&#xff0c;请读者注意。编写程序&#xff1a; 运行结果&#xff1a; 程序分析&#xff1a; 在一般情况下&#xff0c;两个不同的类是互不相干的。display函…

华为云OBS 通过S3客户端访问

华为云好像没有对S3协议的支持说明其实底层是支持S3协议的。 使用S3的时候我们会需要endpoint&#xff0c;桶名字&#xff0c;region&#xff0c;AWS_ACCESS_KEY,AWS_SECRET_KEY 其中endpoint 就是图片中的&#xff0c;桶名字也很容易找到&#xff0c;region 就是你的endpoint…

Nestjs基础

一、创建项目 1、创建 安装 Nest CLI&#xff08;只需要安装一次&#xff09; npm i -g nestjs/cli 进入要创建项目的目录&#xff0c;使用 Nest CLI 创建项目 nest new 项目名 运行项目 npm run start 开发环境下运行&#xff0c;自动刷新服务 npm run start:dev 2、…

Maven一键配置阿里云远程仓库,让你的项目依赖飞起来!

文章目录 引言一、为什么选择阿里云Maven仓库&#xff1f;二、如何设置Maven阿里云远程仓库&#xff1f;三、使用阿里云Maven仓库的注意事项总结 引言 在软件开发的世界里&#xff0c;Maven无疑是一个强大的项目管理工具&#xff0c;它能够帮助我们自动化构建、依赖管理和项目…

C++初学者指南-5.标准库(第一部分)--迭代器

C初学者指南-5.标准库(第一部分)–迭代器 Iterators 文章目录 C初学者指南-5.标准库(第一部分)--迭代器 Iterators1.默认正向迭代器2.反向迭代器3.基于迭代器的循环4.示例&#xff1a;交换相邻的一对元素5.迭代器范围6.迭代器范围中的元素数量7. 总结&#xff1a;迭代器 指向某…

动态规划|剑指 Offer II 093. 最长斐波那契数列

如果数组 arr 中存在三个下标 i、j、k 满足 arr[i]>arr[j]>arr[k] 且 arr[k]arr[j]arr[i]&#xff0c;则 arr[k]、arr[j] 和 arr[i] 三个元素组成一个斐波那契式子序列。由于数组 arr 严格递增&#xff0c;因此 arr[i]>arr[j]>arr[k] 等价于 i>j>k。 把这道题…

记录第一次使用air热更新golang项目

下载 go install github.com/cosmtrek/airlatest 下载时提示&#xff1a; module declares its path as: github.com/air-verse/air but was required as: github.com/cosmtrek/air 此时&#xff0c;需要在go.mod中加上这么一句&#xff1a; replace github.com/cosmtrek/air &…

jmeter-beanshell学习4-beanshell截取字符串

再写个简单点的东西&#xff0c;截取字符串&#xff0c;参数化文件统一用csv&#xff0c;然后还要用excel打开&#xff0c;如果是数字很容易格式就乱了。有同事是用双引号把数字引起来&#xff0c;报文里就不用加引号了&#xff0c;但是这样beanshell处理起来&#xff0c;好像容…

插入排序——C语言

假设我们现在有一个数组&#xff0c;对它进行排序&#xff0c;插入排序的算法如同它的名字一样&#xff0c;就是将元素一个一个插入到合适的位置&#xff0c;那么&#xff0c;该如何做呢&#xff1f; 如果我们要从小到大进行排序的话&#xff0c;步骤如下&#xff1a; 1.对于…

LabVIEW机器视觉系统中的图像畸变、校准和矫正

在机器视觉应用中&#xff0c;图像畸变、校准和矫正是确保图像准确性的关键步骤。LabVIEW作为一种强大的图像处理和分析工具&#xff0c;提供了一系列功能来处理这些问题。以下是对图像畸变、校准和矫正的详细介绍。 图像畸变 图像畸变 是指由于摄像镜头的光学特性或拍摄角度问…

二分法查找有序表的通用算法(可查链表,数组,字符串...等等)

find_binary函数 注意事项&#xff1a; &#xff08;1&#xff09;你设计的迭代器模板中必须有using value_type T&#xff0c;且有加减运算功能&#xff0c;其本上能与C标准库std中一样。 &#xff08;2&#xff09;集合必须是有序的。 下面是函数代码&#xff1a; /// &…

土豆炒肉做法

菜单&#xff1a;土豆、葱、铁辣子、纯瘦肉、淀粉、生抽、酱油、刀、案板、十三香、盐巴、擦板 流程&#xff1a; 洗土豆&#xff0c;削皮&#xff0c;擦成条&#xff0c;用凉水过滤两遍淀粉&#xff0c;顺便放个燥里洗肉&#xff0c;切成条&#xff0c;按照生抽、酱油、淀粉、…

react dangerouslySetInnerHTML将html字符串以变量方式插入页面,点击后出现编辑状态

1.插入变量 出现以下编辑状态 2.解决 给展示富文本的标签添加css样式 pointerEvents: none

JAVA之(方法的重载与重写、this关键字、super关键字)

方法的重载与重写 一、方法的重载与重写1、回顾方法的定义2、重载的概念3、重写 二、this关键字1、何为this方法2、使用方法&#xff08;1&#xff09;在构造方法中指构造器所创建的新对象&#xff08;2&#xff09; 方法中指调用该方法的对象&#xff08;3&#xff09; 在类本…