【OpenAI Triton】理解矩阵乘法中的super-grouping 21a649eddf854db5ad4c7753afb7cb72

【OpenAI Triton】理解矩阵乘法中的super-grouping

前言

最近做推理加速,会涉及一些底层算子的工作,老早就听说triton写算子比较方便,最近正好有一些应用场景,就根据官方文档和大佬们的见解记录一下自己的所学所得;

参考

  • 官方矩阵乘法示例
  • http://giantpandacv.com/project/%E9%83%A8%E7%BD%B2%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%BC%96%E8%AF%91%E5%99%A8/OpenAI%20Triton%20MLIR%20%E7%AC%AC%E4%B8%80%E7%AB%A0%20Triton%20DSL/
  • https://www.zhihu.com/question/622685131

本文主要是记录自己在理解学习时对其中一块内容的理解,并不是做复述或翻译一遍官方文档的内容。所以阅读本文前建议先根据官方文档自己跑一遍矩阵乘法的示例,对triton的功能有个大致的理解,然后再来过其中每一行的代码;如果你对cuda等比较熟悉,看完之后可能就直接秒懂,哈哈哈

L2 Cache Optimizations

原始的实现

pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

l2 cache 优化后的实现

# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of programs ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m

首先讨论为何需要进行L2 Cache优化。简单来说,GPU硬件中存在寄存器、L1 Cache、L2 Cache和全局内存等结构,它们的读写效率逐级降低。

寄存器是GPU中最快速的存储器,用于存储线程的变量和计算中间结果。每个线程都有自己的一组寄存器,能够进行快速访问。然而,寄存器的数量非常有限,通常只有几十到几百个。对于计算密集型任务,如矩阵乘法,可以利用寄存器来存储临时变量和迭代计算中的中间结果,以减少对其他内存层次的访问。

L1 Cache位于GPU SM(Streaming Multiprocessor)内部,用于存储频繁访问的数据和指令。它是一个相对较小但速度较快的缓存,用于提高数据的局部性和访问效率。L1 Cache主要用于存储线程级别的数据,如线程的寄存器溢出数据、局部变量以及线程块内共享内存的数据。

L2 Cache是位于GPU SM之上的一个更大的缓存层次。它的容量通常比L1 Cache大数倍,但速度相对较慢。L2 Cache用于存储来自多个SM的数据,并提供更大的缓存容量以提高数据的局部性和复用性。L2 Cache能够减少对全局内存的访问,从而提高数据访问效率和整体性能。

回到矩阵乘法的优化,由于它是计算密集型操作,数据传输损耗对性能影响非常严重。因此,能够利用最近的数据存储器是至关重要的。

通常情况下,矩阵乘法会按照一个矩阵块的大小进行计算。在每次计算之前,所需的数据会从全局内存加载到L2 Cache中,然后在SM执行过程中直接从L2 Cache读取和写入数据。命中率指的是计算所需的数据能否直接从L2 Cache获取,高命中率意味着可以减少对全局内存的数据获取,从而避免大量的数据传输性能损耗。

Triton与cutlass或cuda编程的区别

以我目前的浅薄理解,Triton的编程模型主要集中在块(block)级别上,即用户无需过多关注块内部的线程计算过程。而Cutlass或CUDA编程往往更注重于细粒度的线程级别编程。因此,Triton在抽象层面上更高级,可以提高开发效率,但在性能和资源控制方面可能稍显不足。

理解Row-major ordering

pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

请添加图片描述

结合这段代码和这幅图,我们来分析row-major ordering的block循环逻辑。

在图中,可以看到矩阵A、B、C都是9x9的大小,但是要注意每个黄色格子代表一个block。如果我们设定一个BLOCK_SIZE_M x BLOCK_SIZE_N大小为64x64,那么矩阵A和B的大小都将是576x576。这也是之前所说的triton是基于block逻辑进行编程的。

在运行时,一个SM可能会同时计算多个block,而多个SM则可以并行计算更多的block。但是无论是哪个SM计算,它所需的矩阵数据都会优先从L2 Cache中获取。这与之前解释的L2缓存命中率密切相关。

pid = triton.program_id(0);

这里的program_id是一个非常重要的概念。我们编写的程序只确定了一个block的计算过程,而所有block的计算是由编译器来编译循环。这行代码实际上是在确定这个block在循环逻辑中的位置。其中的axis=0表示这个“循环”是一维的,即只有一层。如果还有axis=1,那就意味着还有嵌套的第二层。这些不同的block是并行执行的(不同的物理硬件)。

grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;

这两行比较好理解,就是计算出在block维度,行和列block的数量;

pid_m = pid / grid_n;
pid_n = pid % grid_n;

这两行代码是row-major ordering的核心逻辑,也是最简单的逻辑。在triton编程中,除了确定每个block内部的计算逻辑外,还可以根据pid(program_id)确定block的遍历逻辑,这是一个非常关键的概念。

根据之前的说明,这里的pid只有一维,范围是从0到80。在这个9x9的矩阵中,我们需要确定如何将0到80的序号填入其中,这就是所谓的block ordering逻辑。在这个例子中,我们按行遍历矩阵来确定pid → (pid_m, pid_n)的值,所以被称为row-major ordering(按行优先顺序)。

row-major ordering下的读写

这个官方解释得很清楚,我们以计算9个block为例来说明。在row-major ordering的模式下,对于矩阵A来说,需要读取9个block的数据;而对于矩阵B来说,需要读取81个block的数据;最后,矩阵C需要写入9个block的数据。因此,总共需要读取90个block的数据,写入9个block的数据。

Super-Grouping Ordering

请添加图片描述

看官方给的图,先说结论,同样在写入9个block的数据时,矩阵A和矩阵B都需要读取27个block的数据,总共涉及54个block的读取操作。相比于row-major ordering,这是一个显著的改进。

通常情况下,较高的L2缓存命中率通常意味着较少的读写次数,而较低的L2缓存命中率则通常伴随着更多的读写次数。

由于L2缓存是有限的,想象一下进行一次密集计算操作时,同时有大量的SM并行运行。如果存在大量的读写操作,无疑会对L2缓存的数据存储产生影响。当矩阵的规模很小,只需要一个指令就能完成所有数据的计算时,即所有的数据都能放到L2缓存中,L2缓存的影响就不明显。然而,在实际情况下,这种情况是不太可能的。

排布逻辑

如果我们能够完全理解row-major ordering的排布过程,那么其他的排布逻辑其实也就很容易理解了。这是因为它们的原理是相同的,都是通过pid(program_id)来确定(pid_m, pid_n)的值,即在一个9x9的block矩阵中按照希望的顺序填入pid序号。

例如,对于super-grouping的结构,它实际上是将一个block按照横向和纵向同时进行拓展,形成一个小矩形。这个小矩形看起来就像一个超级小组

在实际编程中,我们可以根据具体的需求和算法的特性,选择不同的排布逻辑来组织block的布局。无论是row-major ordering、column-major ordering还是super-grouping,它们的核心思想都是通过pid来确定每个block在整个block矩阵中的位置和顺序。

理解这些排布逻辑有助于我们更好地设计并行计算任务的数据布局,从而利用好计算资源,提高计算效率和性能。

接下来按行阐述其排布过程

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

引用董鑫大佬的两幅图(参考的第三个链接)
请添加图片描述
请添加图片描述
前三行代码逻辑是一致的,不再赘述;

num_pid_in_group = GROUP_SIZE_M * num_pid_n GROUP_SIZE_M是行方向的组大小,这里定义为3,即上面第一幅图的红色框框,num_pid_in_group 就是计算该组内一共有多少个block;

group_id = pid // num_pid_in_group 就是判断对于当前pid它是在哪个group;

first_pid_m = group_id * GROUP_SIZE_M 计算当前group第一个pid_m的编号,注意是pid_m,上面提到,排布逻辑其实就是将pid映射到(pid_m, pid_n)的过程;

group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) 这一步是为了避免最后一个group是无法整除的,当前这个例子正好是整除的,所以看不太出来。稍微阐述一下,假如无法整除,设最后一个group只有2行,因为是按列排序,在算pid在这个group中对应的pid_m时,假如pid是30,那么其行号就应该是(30-27)%2=1;结合图2可以对比一下。

pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

这两行就是将pid映射到(pid_m,pid_n)的最终逻辑代码了;

一个例子

接上图2,我们对pid=30的block,来计算一下其对应的实际pid_m和pid_n。

pid = 30
num_pid_m = 9
num_pid_n = 9
GROUP_SIZE_M = 3
num_pid_in_group = 3 * 9 = 27 # 一组有27个pid
group_id = 30 // 27 = 1 # 在第1组
first_pid_m = 1 * 3 = 3 # 第一组第一个pid的行号为3
group_size_m = min(9 - 3, 3) = 3 # 不是最后一组也不是非整除,所以不影响

pid_m = 3 + (30 % 3) = 3 + 0 = 3 # 按列排序,所以取模group_size_m
pid_n = (30 % 27) // 3 = 3 // 3 = 1

pid -> (pid_m, pid_n) <==> 30 -> (3, 1) # 根据图2对比一下

至此讲完block的逻辑排布;后面可能还会再补充一些东西

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434584.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CMake-深入理解find_package()的用法

前言&#xff1a; CMake给我们提供了find_package()命令用来查找依赖包&#xff0c;理想情况下&#xff0c;一句find_package()命令就能把一整个依赖包的头文件包含路径、库路径、库名字、版本号等情况都获取到&#xff0c;后续只管用就好了。但实际使用过程可能会出现这样那样…

#微信小程序创建(获取onenet平台数据)

1.IDE&#xff1a;微信开发者工具 2.实验&#xff1a;创建一个小程序&#xff08;http get获取onenet平台数据&#xff09; 3.记录&#xff1a; 百度网盘链接&#xff1a;https://pan.baidu.com/s/1eOd-2EnilnhPWoGUMj0fzw 提取码: 2023 &#xff08;1&#xff09;新建一个工…

【工具相关】zentao用例管理平台部署实践

文章目录 一、备份还原1、数据备份1.1、前言1.2、版本备份1.3、数据备份 2、数据恢复2.1、版本恢复2.2、数据恢复 二、问题处理1、ERROR: SQLSTATE[HY000] [2002] Connection refused 一、备份还原 1、数据备份 1.1、前言 禅道系统从10.6版本以后&#xff0c;新增数据备份设…

lv20 QT进程线程编程

知识点&#xff1a;启动进程 &#xff0c;线程 &#xff0c;线程同步互斥 1 启动进程 应用场景&#xff1a;通常在qt中打开另一个程序 process模板 QString program “/bin/ls"; QStringList arguments; arguments << "-l" << “-a";QPro…

Java进阶-IO(4)

前面几篇介绍了java IO的基础部分&#xff0c;现在进入核心内容的学习&#xff0c;如File类、动态读取和序列化等&#xff0c;如下。 一、File类 1、概述 是 java.io 包中唯一代表磁盘文件本身的对象&#xff08;可以通过 File 类操作文件和目录&#xff09;&#xff0c;定义…

【Flutter 】get-cli init报错处理

报错内容 get init 命令跳出,报错内如下 Select which type of project you want to creat Synchronous waiting using dart:cli waitFor Unhandled exceotion . Dart WaitforEvent is deprecated and disabled by default. This feature will be fully removed in Dart 3.4 …

Docker安装MySQL镜像实战分享

今天我们对Docker安装MySQL镜像进行实战分享&#xff0c;以更深入的了解容器的使用场景。我们在云付服务器Ubuntu环境上已经安装好了Docker&#xff0c;接下来我们开始安装mysql5.7版本&#xff0c;安装mysql有两种思路&#xff0c;直接拉取mysql镜像和自己做mysql镜像&#xf…

【python基础学习10课_面向对象、封装、继承、多态】

一、类与对象 1、类的定义 在类的里面&#xff0c;称之为方法。 在类的外面&#xff0c;称之为函数。类&#xff1a;人类&#xff0c;一个族群&#xff0c;是一个群体类的语法规则&#xff1a;class 自定义的类名():属性 -- 变量方法 -- 函数类&#xff0c;首字母大写&#x…

软考-中级-系统集成2023年综合知识(五)

&#x1f339;作者主页&#xff1a;青花锁 &#x1f339;简介&#xff1a;Java领域优质创作者&#x1f3c6;、Java微服务架构公号作者&#x1f604; &#x1f339;简历模板、学习资料、面试题库、技术互助 &#x1f339;文末获取联系方式 &#x1f4dd; 软考中级专栏回顾 专栏…

阿里云服务器配置选择哪个比较好?看花眼了

阿里云服务器配置怎么选择&#xff1f;CPU内存、公网带宽和系统盘怎么选择&#xff1f;个人开发者或中小企业选择轻量应用服务器、ECS经济型e实例&#xff0c;企业用户选择ECS通用算力型u1云服务器、ECS计算型c7、通用型g7云服务器&#xff0c;阿里云服务器网aliyunfuwuqi.com整…

Java常用笔试题,面试java对未来的规划

最重要的话 2021年&#xff0c;真希望行业能春暖花开。 去年由于疫情的影响&#xff0c;无数行业都受到了影响&#xff0c;互联网寒冬下&#xff0c;许多程序员被裁&#xff0c;大环境格外困难。 我被公司裁掉后&#xff0c;便着急地开始找工作&#xff0c;一次次地碰壁&#…

微信jsSDK前端签名错误,巨坑; 前后端分离的一大失误。

微信 JS 接口签名校验工具 1. 确保你后端生成的签名是正确&#xff0c;这个是第一步。否则后面都是白扯。 以用上面微信自带的验证签名工龄进行验证。 确保生成的签名和你的签名是一致的。 2. timestamp需要是字符串类型. 3. 切记&#xff0c;URL不要encode&#xff0c;如果…

9.12零钱兑换(LC518-M)(开始完全背包,与01背包的不同仅在于遍历顺序)

算法&#xff1a; 这是一道典型的背包问题&#xff0c;一看到钱币数量不限&#xff0c;就知道这是一个完全背包。 但本题和纯完全背包不一样&#xff0c;纯完全背包是凑成背包最大价值是多少&#xff0c;而本题是要求凑成总金额的物品组合个数&#xff01; 动规五步曲&#…

剑指offer--c++--n个骰子的点数

目录 题目&#xff1a; 题目分析&#xff1a; 最后编写代码&#xff1a; 输出结果 题目&#xff1a; 把n个骰子扔在地上&#xff0c;所有骰子朝上一面的点数之和为s。输入n&#xff0c;打印出s的所有可能的值出现的概率。 感谢大佬的帮助&#xff1a;https://www.cnblogs.c…

人人都写过的6个bug

大家好&#xff0c;我是知微。 程序员写bug几乎是家常便饭&#xff0c;也是我们每个人成长过程中难以避免的一部分。 为了缓解这份“尴尬”&#xff0c;今天想和大家分享一些曾经都会遇到过的bug&#xff0c;让我们一起来看看这些“经典之作”。 1、数组越界 #include <…

数据库-DDL

show databases; 查询所有数据库 select database(); 查询当前数据库 use 数据库名&#xff1b; 使用数据库 creat database[if not exists] 数据库名…

成都源聚达:开抖音店铺分数需要达到多少

在数字化浪潮中&#xff0c;抖音以其独特的平台魅力吸引了无数商家入驻。但想要开设一家抖音店铺并非随意之举&#xff0c;它需要商家达到一定的评分标准。这如同参加一场考试&#xff0c;只有成绩合格者才有资格入座。那么&#xff0c;这个分数线究竟是多少呢? 据官方数据显示…

力扣hot100:560.和为K的子数组(前缀和+哈希表)

分析&#xff1a; 这个题目乍一看&#xff0c;数据大小用暴力解法大概率会超时&#xff0c;可能想用双指针&#xff0c;但是问题出现在 可能存在负数&#xff0c;也就是说即使是找到了一个答案&#xff0c;后面也可能存在负数和正数抵消&#xff0c;又是答案&#xff0c;因此不…

【Linux】文件传输工具lrzsz的安装与使用

目录 一、关于lrzsz 二、安装lrzsz 三、lrzsz的说明及使用 1、上传命令rz 2、下载命令sz 一、关于lrzsz 在开发的过程中&#xff0c;经常遇到 需要在 Linux 和 Windows 之间上传下载文件的情况 这时&#xff0c;一般都是使用 FTP 或者 WinSCP 工具进行上传下载, 虽然也能…

用ChatGPT计算植被归一化指数NDVI并出图的详细教程

用ChatGPT结合GIS计算植被归一化指数NDVI出图教程 用ENVI计算比较繁琐&#xff0c;如今AI的盛行&#xff0c;我们可以轻松解决计算问题&#xff0c;只需1一分钟变可以出图。 详细教学请看上方视频步骤。 更多ChatGPT教学内容请见&#xff1a;ChatGPT结合GIS&#xff1a;一分钟…