【OpenAI Triton】理解矩阵乘法中的super-grouping
前言
最近做推理加速,会涉及一些底层算子的工作,老早就听说triton写算子比较方便,最近正好有一些应用场景,就根据官方文档和大佬们的见解记录一下自己的所学所得;
参考
- 官方矩阵乘法示例
- http://giantpandacv.com/project/%E9%83%A8%E7%BD%B2%E4%BC%98%E5%8C%96/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%BC%96%E8%AF%91%E5%99%A8/OpenAI%20Triton%20MLIR%20%E7%AC%AC%E4%B8%80%E7%AB%A0%20Triton%20DSL/
- https://www.zhihu.com/question/622685131
本文主要是记录自己在理解学习时对其中一块内容的理解,并不是做复述或翻译一遍官方文档的内容。所以阅读本文前建议先根据官方文档自己跑一遍矩阵乘法的示例,对triton的功能有个大致的理解,然后再来过其中每一行的代码;如果你对cuda等比较熟悉,看完之后可能就直接秒懂,哈哈哈
L2 Cache Optimizations
原始的实现
pid = triton.program_id(0); grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M; grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N; pid_m = pid / grid_n; pid_n = pid % grid_n;
l2 cache 优化后的实现
# Program ID pid = tl.program_id(axis=0) # Number of program ids along the M axis num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) # Number of programs ids along the N axis num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # Number of programs in group num_pid_in_group = GROUP_SIZE_M * num_pid_n # Id of the group this program is in group_id = pid // num_pid_in_group # Row-id of the first program in the group first_pid_m = group_id * GROUP_SIZE_M # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # *Within groups*, programs are ordered in a column-major order # Row-id of the program in the *launch grid* pid_m = first_pid_m + (pid % group_size_m) # Col-id of the program in the *launch grid* pid_n = (pid % num_pid_in_group) // group_size_m
首先讨论为何需要进行L2 Cache优化。简单来说,GPU硬件中存在寄存器、L1 Cache、L2 Cache和全局内存等结构,它们的读写效率逐级降低。
寄存器是GPU中最快速的存储器,用于存储线程的变量和计算中间结果。每个线程都有自己的一组寄存器,能够进行快速访问。然而,寄存器的数量非常有限,通常只有几十到几百个。对于计算密集型任务,如矩阵乘法,可以利用寄存器来存储临时变量和迭代计算中的中间结果,以减少对其他内存层次的访问。
L1 Cache位于GPU SM(Streaming Multiprocessor)内部,用于存储频繁访问的数据和指令。它是一个相对较小但速度较快的缓存,用于提高数据的局部性和访问效率。L1 Cache主要用于存储线程级别的数据,如线程的寄存器溢出数据、局部变量以及线程块内共享内存的数据。
L2 Cache是位于GPU SM之上的一个更大的缓存层次。它的容量通常比L1 Cache大数倍,但速度相对较慢。L2 Cache用于存储来自多个SM的数据,并提供更大的缓存容量以提高数据的局部性和复用性。L2 Cache能够减少对全局内存的访问,从而提高数据访问效率和整体性能。
回到矩阵乘法的优化,由于它是计算密集型操作,数据传输损耗对性能影响非常严重。因此,能够利用最近的数据存储器是至关重要的。
通常情况下,矩阵乘法会按照一个矩阵块的大小进行计算。在每次计算之前,所需的数据会从全局内存加载到L2 Cache中,然后在SM执行过程中直接从L2 Cache读取和写入数据。命中率指的是计算所需的数据能否直接从L2 Cache获取,高命中率意味着可以减少对全局内存的数据获取,从而避免大量的数据传输性能损耗。
Triton与cutlass或cuda编程的区别
以我目前的浅薄理解,Triton的编程模型主要集中在块(block)级别上,即用户无需过多关注块内部的线程计算过程。而Cutlass或CUDA编程往往更注重于细粒度的线程级别编程。因此,Triton在抽象层面上更高级,可以提高开发效率,但在性能和资源控制方面可能稍显不足。
理解Row-major ordering
pid = triton.program_id(0);
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;
结合这段代码和这幅图,我们来分析row-major ordering的block循环逻辑。
在图中,可以看到矩阵A、B、C都是9x9的大小,但是要注意每个黄色格子代表一个block。如果我们设定一个BLOCK_SIZE_M x BLOCK_SIZE_N大小为64x64,那么矩阵A和B的大小都将是576x576。这也是之前所说的triton是基于block逻辑进行编程的。
在运行时,一个SM可能会同时计算多个block,而多个SM则可以并行计算更多的block。但是无论是哪个SM计算,它所需的矩阵数据都会优先从L2 Cache中获取。这与之前解释的L2缓存命中率密切相关。
pid = triton.program_id(0);
这里的program_id
是一个非常重要的概念。我们编写的程序只确定了一个block的计算过程,而所有block的计算是由编译器来编译循环。这行代码实际上是在确定这个block在循环逻辑中的位置。其中的axis=0
表示这个“循环”是一维的,即只有一层。如果还有axis=1
,那就意味着还有嵌套的第二层。这些不同的block是并行执行的(不同的物理硬件)。
grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M;
grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N;
这两行比较好理解,就是计算出在block维度,行和列block的数量;
pid_m = pid / grid_n;
pid_n = pid % grid_n;
这两行代码是row-major ordering的核心逻辑,也是最简单的逻辑。在triton编程中,除了确定每个block内部的计算逻辑外,还可以根据pid
(program_id)确定block的遍历逻辑,这是一个非常关键的概念。
根据之前的说明,这里的pid
只有一维,范围是从0到80。在这个9x9的矩阵中,我们需要确定如何将0到80的序号填入其中,这就是所谓的block ordering逻辑。在这个例子中,我们按行遍历矩阵来确定pid → (pid_m, pid_n)的值,所以被称为row-major ordering(按行优先顺序)。
row-major ordering下的读写
这个官方解释得很清楚,我们以计算9个block为例来说明。在row-major ordering的模式下,对于矩阵A来说,需要读取9个block的数据;而对于矩阵B来说,需要读取81个block的数据;最后,矩阵C需要写入9个block的数据。因此,总共需要读取90个block的数据,写入9个block的数据。
Super-Grouping Ordering
看官方给的图,先说结论,同样在写入9个block的数据时,矩阵A和矩阵B都需要读取27个block的数据,总共涉及54个block的读取操作。相比于row-major ordering,这是一个显著的改进。
通常情况下,较高的L2缓存命中率通常意味着较少的读写次数,而较低的L2缓存命中率则通常伴随着更多的读写次数。
由于L2缓存是有限的,想象一下进行一次密集计算操作时,同时有大量的SM并行运行。如果存在大量的读写操作,无疑会对L2缓存的数据存储产生影响。当矩阵的规模很小,只需要一个指令就能完成所有数据的计算时,即所有的数据都能放到L2缓存中,L2缓存的影响就不明显。然而,在实际情况下,这种情况是不太可能的。
排布逻辑
如果我们能够完全理解row-major ordering的排布过程,那么其他的排布逻辑其实也就很容易理解了。这是因为它们的原理是相同的,都是通过pid
(program_id)来确定(pid_m, pid_n)
的值,即在一个9x9的block矩阵中按照希望的顺序填入pid
序号。
例如,对于super-grouping的结构,它实际上是将一个block按照横向和纵向同时进行拓展,形成一个小矩形。这个小矩形看起来就像一个超级小组
在实际编程中,我们可以根据具体的需求和算法的特性,选择不同的排布逻辑来组织block的布局。无论是row-major ordering、column-major ordering还是super-grouping,它们的核心思想都是通过pid
来确定每个block在整个block矩阵中的位置和顺序。
理解这些排布逻辑有助于我们更好地设计并行计算任务的数据布局,从而利用好计算资源,提高计算效率和性能。
接下来按行阐述其排布过程
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
引用董鑫大佬的两幅图(参考的第三个链接)
前三行代码逻辑是一致的,不再赘述;
num_pid_in_group = GROUP_SIZE_M * num_pid_n
GROUP_SIZE_M是行方向的组大小,这里定义为3,即上面第一幅图的红色框框,num_pid_in_group
就是计算该组内一共有多少个block;
group_id = pid // num_pid_in_group
就是判断对于当前pid它是在哪个group;
first_pid_m = group_id * GROUP_SIZE_M
计算当前group第一个pid_m的编号,注意是pid_m,上面提到,排布逻辑其实就是将pid映射到(pid_m, pid_n)的过程;
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
这一步是为了避免最后一个group是无法整除的,当前这个例子正好是整除的,所以看不太出来。稍微阐述一下,假如无法整除,设最后一个group只有2行,因为是按列排序,在算pid在这个group中对应的pid_m时,假如pid是30,那么其行号就应该是(30-27)%2=1;结合图2可以对比一下。
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
这两行就是将pid映射到(pid_m,pid_n)的最终逻辑代码了;
一个例子
接上图2,我们对pid=30的block,来计算一下其对应的实际pid_m和pid_n。
pid = 30
num_pid_m = 9
num_pid_n = 9
GROUP_SIZE_M = 3
num_pid_in_group = 3 * 9 = 27 # 一组有27个pid
group_id = 30 // 27 = 1 # 在第1组
first_pid_m = 1 * 3 = 3 # 第一组第一个pid的行号为3
group_size_m = min(9 - 3, 3) = 3 # 不是最后一组也不是非整除,所以不影响
pid_m = 3 + (30 % 3) = 3 + 0 = 3 # 按列排序,所以取模group_size_m
pid_n = (30 % 27) // 3 = 3 // 3 = 1
pid -> (pid_m, pid_n) <==> 30 -> (3, 1) # 根据图2对比一下
至此讲完block的逻辑排布;后面可能还会再补充一些东西