[AI Compilers] Learning Triton: Optimizing Matrix Multiplication

Matrix Multiplication¶

Main topics:

Block-level matrix multiplication

Multi-dimensional pointer arithmetic

Program re-ordering for a better L2 cache hit rate

Automatic performance tuning

Motivations¶

Matrix multiplication is a key building block of most modern high-performance computing systems. It is notoriously hard to optimize, so its implementation is usually provided by hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). These libraries are often proprietary and typically cannot be freely modified to accommodate the special needs of deep learning workloads (e.g., fused activation functions). This tutorial shows you how to write your own efficient, easily customizable matrix multiplication kernel with Triton.

Roughly speaking, the kernel we will write implements the following blocked algorithm to multiply an (M, K) matrix by a (K, N) matrix:

# Do in parallel
for m in range(0, M, BLOCK_SIZE_M):
  # Do in parallel
  for n in range(0, N, BLOCK_SIZE_N):
    acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32)
    for k in range(0, K, BLOCK_SIZE_K):
      a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
      b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
      acc += dot(a, b)
    C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc

Each iteration of this doubly-nested (parallel) loop is executed by a dedicated Triton program instance.
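
To make the blocked scheme concrete, here is a minimal NumPy sketch of the same algorithm (the `blocked_matmul` name and the block sizes are illustrative, and the sketch assumes M, N, K are multiples of the block sizes; the real kernel below handles ragged edges with masking). Each (m, n) iteration corresponds to one Triton program instance:

import numpy as np

def blocked_matmul(A, B, BLOCK_M=16, BLOCK_N=16, BLOCK_K=16):
    # Sequential stand-in for the parallel (m, n) grid above.
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, BLOCK_M):          # "Do in parallel" in Triton
        for n in range(0, N, BLOCK_N):      # "Do in parallel" in Triton
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
            for k in range(0, K, BLOCK_K):  # sequential reduction over K
                acc += A[m:m+BLOCK_M, k:k+BLOCK_K] @ B[k:k+BLOCK_K, n:n+BLOCK_N]
            C[m:m+BLOCK_M, n:n+BLOCK_N] = acc
    return C

A = np.random.randn(64, 48).astype(np.float32)
B = np.random.randn(48, 32).astype(np.float32)
assert np.allclose(blocked_matmul(A, B), A @ B, atol=1e-3)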

Compute Kernel¶

The algorithm above is actually fairly straightforward to implement in Triton. The tricky part is computing the memory locations from which blocks of A and B must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.

Pointer Arithmetic¶

For a row-major 2D tensor X, the memory address of X[i, j] is &X[i, j] = X + i*stride_xi + j*stride_xj. Therefore, the blocks of pointers for A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] and B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] can be expressed in pseudocode as:

&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] =  a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] =  b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1);
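
As a quick sanity check of this address formula, here is a small sketch (illustrative, not part of the tutorial) that recomputes the flat element offsets of a block of a row-major torch tensor and compares the gather against plain slicing:

import torch

X = torch.arange(6 * 8).reshape(6, 8)  # row-major: stride(0) = 8, stride(1) = 1 (in elements)
m, k, BLOCK_M, BLOCK_K = 2, 4, 3, 4
rows = torch.arange(m, m + BLOCK_M)[:, None]       # shape (BLOCK_M, 1)
cols = torch.arange(k, k + BLOCK_K)[None, :]       # shape (1, BLOCK_K)
offsets = rows * X.stride(0) + cols * X.stride(1)  # (BLOCK_M, BLOCK_K) offsets, as in the pseudocode
# Gathering through the computed offsets reproduces the plain 2D slice.
assert torch.equal(X.flatten()[offsets], X[m:m+BLOCK_M, k:k+BLOCK_K])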

This means that, in Triton, the pointer blocks for A and B can be initialized as follows (i.e., for k=0). Note the extra modulo, which handles the case where M is not a multiple of BLOCK_SIZE_M (or N of BLOCK_SIZE_N): the out-of-bounds offsets wrap around, effectively padding the computation with values that will not contribute to the final result. Out-of-bounds accesses along the K dimension will be handled later with masked loads.

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

These pointers are then advanced inside the inner loop:

a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

L2 Cache Optimizations¶

As mentioned above, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] block of the C matrix. It is important to remember that the order in which these blocks are computed matters, because it affects the L2 cache hit rate of our program. Unfortunately, a simple row-major ordering,

pid = tl.program_id(axis=0)
grid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // grid_n
pid_n = pid % grid_n

is just not going to cut it.

One possible solution is to launch blocks in an order that promotes data reuse. This can be done by "super-grouping" blocks into groups of GROUP_SIZE_M rows before switching to the next column:

# Program ID
pid = tl.program_id(axis=0)
# Number of program ids along the M axis
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
# Number of program ids along the N axis
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# Number of programs in group
num_pid_in_group = GROUP_SIZE_M * num_pid_n
# Id of the group this program is in
group_id = pid // num_pid_in_group
# Row-id of the first program in the group
first_pid_m = group_id * GROUP_SIZE_M
# If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# *Within groups*, programs are ordered in a column-major order
# Row-id of the program in the *launch grid*
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# Col-id of the program in the *launch grid*
pid_n = (pid % num_pid_in_group) // group_size_m

For example, consider a matmul where each matrix consists of 9 blocks by 9 blocks. If we compute the output in row-major ordering, we need to load 90 blocks into SRAM to compute the first 9 output blocks, whereas with grouped ordering we only need to load 54 blocks.



In practice, this can improve the performance of our matmul kernel by more than 10% on some hardware (e.g., from 220 to 245 TFLOPS on an A100).
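
The 90-vs-54 block count can be checked with a small, self-contained simulation (illustrative only; `swizzle` and `blocks_loaded` are hypothetical helpers that mirror the grouped mapping used in the kernel below). It walks the first 9 program ids of a 9x9 grid in each ordering and counts the A row-panels and B column-panels that must be brought into SRAM:

def swizzle(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    # Same grouped (pid_m, pid_n) mapping as in matmul_kernel below.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

def blocks_loaded(order, k_blocks=9):
    # Each output block consumes a full A row-panel and a full B column-panel,
    # i.e. k_blocks blocks from A plus k_blocks blocks from B.
    a_rows = {pid_m for pid_m, _ in order}
    b_cols = {pid_n for _, pid_n in order}
    return (len(a_rows) + len(b_cols)) * k_blocks

row_major = [(pid // 9, pid % 9) for pid in range(9)]
grouped = [swizzle(pid, 9, 9, GROUP_SIZE_M=3) for pid in range(9)]
print(blocks_loaded(row_major))  # 90: 1 A row-panel + 9 B column-panels, 9 K-blocks each
print(blocks_loaded(grouped))    # 54: 3 A row-panels + 3 B column-panels, 9 K-blocks each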

Final Result¶

import torch

import triton
import triton.language as tl


def is_cuda():
   return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip_mi200():
   target = triton.runtime.driver.active.get_current_target()
   return target.backend == 'hip' and target.arch == 'gfx90a'


def get_cuda_autotune_config():
   return [
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                     num_warps=2),
       triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                     num_warps=2),
       # Good config for fp8 inputs.
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                     num_warps=8),
       triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4),
       triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                     num_warps=4)
   ]


def get_hip_autotune_config():
   return [
       triton.Config(
           {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
           num_warps=4, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
           num_warps=8, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
           num_warps=8, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
           num_warps=4, num_stages=0),
       triton.Config(
           {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
           num_warps=4, num_stages=0),
   ]


def get_autotune_config():
   if is_cuda():
       return get_cuda_autotune_config()
   else:
       return get_hip_autotune_config()


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
   configs=get_autotune_config(),
   key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
       # Pointers to matrices
       a_ptr, b_ptr, c_ptr,
       # Matrix dimensions
       M, N, K,
       # The stride variables represent how much to increase the ptr by when moving by 1
       # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
       # by to get the element one row down (A has M rows).
       stride_am, stride_ak,  #
       stride_bk, stride_bn,  #
       stride_cm, stride_cn,
       # Meta-parameters
       BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
       GROUP_SIZE_M: tl.constexpr,  #
       ACTIVATION: tl.constexpr  #
):
   """Kernel for computing the matmul C = A x B.
   A has shape (M, K), B has shape (K, N) and C has shape (M, N)
   """
   # -----------------------------------------------------------
   # Map program ids `pid` to the block of C it should compute.
   # This is done in a grouped ordering to promote L2 data reuse.
   # See above `L2 Cache Optimizations` section for details.
   pid = tl.program_id(axis=0)
   num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
   num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
   num_pid_in_group = GROUP_SIZE_M * num_pid_n
   group_id = pid // num_pid_in_group
   first_pid_m = group_id * GROUP_SIZE_M
   group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
   pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
   pid_n = (pid % num_pid_in_group) // group_size_m

   # ----------------------------------------------------------
   # Create pointers for the first blocks of A and B.
   # We will advance this pointer as we move in the K direction
   # and accumulate
   # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
   # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
   # See above `Pointer Arithmetic` section for details
   offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
   offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
   offs_k = tl.arange(0, BLOCK_SIZE_K)
   a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
   b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

   # -----------------------------------------------------------
   # Iterate to compute a block of the C matrix.
   # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
   # of fp32 values for higher accuracy.
   # `accumulator` will be converted back to fp16 after the loop.
   accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
   for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
       # Load the next block of A and B, generate a mask by checking the K dimension.
       # If it is out of bounds, set it to 0.
       a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
       b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
       # We accumulate along the K dimension.
       accumulator = tl.dot(a, b, accumulator)
       # Advance the ptrs to the next K block.
       a_ptrs += BLOCK_SIZE_K * stride_ak
       b_ptrs += BLOCK_SIZE_K * stride_bk
   # You can fuse arbitrary activation functions here
   # while the accumulator is still in FP32!
   if ACTIVATION == "leaky_relu":
       accumulator = leaky_relu(accumulator)
   c = accumulator.to(tl.float16)

   # -----------------------------------------------------------
   # Write back the block of the output matrix C with masks.
   offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
   offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
   c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
   c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
   tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
   return tl.where(x >= 0, x, 0.01 * x)

We can now create a convenience wrapper function that takes only two input tensors and (1) checks any shape constraints, (2) allocates the output, and (3) launches the kernel shown above.

def matmul(a, b, activation=""):
   # Check constraints.
   assert a.shape[1] == b.shape[0], "Incompatible dimensions"
   assert a.is_contiguous(), "Matrix A must be contiguous"
   M, K = a.shape
   K, N = b.shape
   # Allocates output.
   c = torch.empty((M, N), device=a.device, dtype=torch.float16)
   # 1D launch kernel where each block gets its own program.
   grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
   matmul_kernel[grid](
       a, b, c,  #
       M, N, K,  #
       a.stride(0), a.stride(1),  #
       b.stride(0), b.stride(1),  #
       c.stride(0), c.stride(1),  #
       ACTIVATION=activation  #
   )
   return c
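
Because `ACTIVATION` is a compile-time meta-parameter, fusing the activation is just a matter of passing its name through the wrapper (the input tensors here are illustrative):

a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = matmul(a, b, activation="leaky_relu")  # leaky_relu runs fused, while the accumulator is still fp32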

Unit Test¶

We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD MI200 devices.
# MI200 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_mi200() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
   print("✅ Triton and Torch match")
else:
   print("❌ Triton and Torch differ")

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
   torch.manual_seed(0)
   a = torch.randn((512, 512), device="cuda", dtype=torch.float16)
   b = torch.randn((512, 512), device="cuda", dtype=torch.float16)
   a = a.to(torch.float8_e5m2)
   # pre-transpose b for efficiency.
   b = b.T
   b = b.to(torch.float8_e5m2)
   triton_output = matmul(a, b)
   torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
   print(f"triton_output_with_fp8_inputs={triton_output}")
   print(f"torch_output_with_fp8_inputs={torch_output}")
   if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
       print("✅ Triton and Torch match")
   else:
       print("❌ Triton and Torch differ")
triton_output_with_fp16_inputs=tensor([[-10.9531,  -4.7109,  15.6953,  ..., -28.4062,   4.3320, -26.4219],
       [ 26.8438,  10.0469,  -5.4297,  ..., -11.2969,  -8.5312,  30.7500],
       [-13.2578,  15.8516,  18.0781,  ..., -21.7656,  -8.6406,  10.2031],
       ...,
       [ 40.2812,  18.6094, -25.6094,  ...,  -2.7598,  -3.2441,  41.0000],
       [ -6.1211, -16.8281,   4.4844,  ..., -21.0312,  24.7031,  15.0234],
       [-17.0938, -19.0000,  -0.3831,  ...,  21.5469, -30.2344, -13.2188]],
      device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-10.9531,  -4.7109,  15.6953,  ..., -28.4062,   4.3320, -26.4219],
       [ 26.8438,  10.0469,  -5.4297,  ..., -11.2969,  -8.5312,  30.7500],
       [-13.2578,  15.8516,  18.0781,  ..., -21.7656,  -8.6406,  10.2031],
       ...,
       [ 40.2812,  18.6094, -25.6094,  ...,  -2.7598,  -3.2441,  41.0000],
       [ -6.1211, -16.8281,   4.4844,  ..., -21.0312,  24.7031,  15.0234],
       [-17.0938, -19.0000,  -0.3831,  ...,  21.5469, -30.2344, -13.2188]],
      device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
triton_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  28.7031,   8.6719, -40.7500],
       [ 10.0000,  37.0000,  -5.5664,  ...,  20.9844,  46.8125,  30.8281],
       [ 19.5625,  -3.0078, -20.0469,  ...,  -2.1309,  -8.0625,  12.5625],
       ...,
       [-18.1562, -34.1562, -27.4219,  ..., -27.3906, -24.0938, -12.3516],
       [ -3.3945,  -8.6250, -23.6562,  ...,  -4.1094,  -3.5332, -16.0781],
       [-23.9688,  -3.2637, -33.6875,  ...,  17.3125, -36.6250,  25.8594]],
      device='cuda:0', dtype=torch.float16)
torch_output_with_fp8_inputs=tensor([[-21.4375,  13.1719,   6.0352,  ...,  28.7031,   8.6719, -40.7500],
       [ 10.0000,  37.0000,  -5.5664,  ...,  20.9844,  46.8125,  30.8281],
       [ 19.5625,  -3.0078, -20.0469,  ...,  -2.1309,  -8.0625,  12.5625],
       ...,
       [-18.1562, -34.1562, -27.4219,  ..., -27.3906, -24.0938, -12.3516],
       [ -3.3945,  -8.6250, -23.6562,  ...,  -4.1094,  -3.5332, -16.0781],
       [-23.9688,  -3.2637, -33.6875,  ...,  17.3125, -36.6250,  25.8594]],
      device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match

Benchmark¶

Square Matrix Performance¶

We can now compare the performance of our kernel against cuBLAS or rocBLAS. Here we take square matrices as an example, but you can adapt the script to benchmark any other matrix shape.

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

configs = []
for fp8_inputs in [False, True]:
   if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
       continue
   configs.append(
       triton.testing.Benchmark(
           x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
           x_vals=[128 * i for i in range(2, 33)],  # Different possible values for `x_name`
           line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
           # Possible values for `line_arg`
           # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
           line_vals=["triton"] if fp8_inputs else [ref_lib.lower(), "triton"],  # Label name for the lines
           line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"],  # Line styles
           styles=[("green", "-"), ("blue", "-")],
           ylabel="TFLOPS",  # Label name for the y-axis
           plot_name="matmul-performance-" +
           ("fp16" if not fp8_inputs else "fp8"),  # Name for the plot, used also as a file name for saving the plot.
           args={"fp8_inputs": fp8_inputs},
       ))


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
   a = torch.randn((M, K), device='cuda', dtype=torch.float16)
   b = torch.randn((K, N), device='cuda', dtype=torch.float16)
   if TORCH_HAS_FP8 and fp8_inputs:
       a = a.to(torch.float8_e5m2)
       b = b.T
       b = b.to(torch.float8_e5m2)
   quantiles = [0.5, 0.2, 0.8]
   if provider == ref_lib.lower():
       ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
   if provider == 'triton':
       ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
   perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)  # 2*M*N*K FLOPs per matmul, reported in TFLOP/s
   return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)




matmul-performance-fp16:
        M       N       K      cuBLAS      Triton
0    256.0   256.0   256.0    4.096000    4.096000
1    384.0   384.0   384.0   12.288000   12.288000
2    512.0   512.0   512.0   26.214401   26.214401
3    640.0   640.0   640.0   42.666665   42.666665
4    768.0   768.0   768.0   63.195428   68.056616
5    896.0   896.0   896.0   78.051553   93.661869
6   1024.0  1024.0  1024.0  110.376426   99.864382
7   1152.0  1152.0  1152.0  135.726544  129.825388
8   1280.0  1280.0  1280.0  163.840004  163.840004
9   1408.0  1408.0  1408.0  155.765024  132.970149
10  1536.0  1536.0  1536.0  176.947204  157.286398
11  1664.0  1664.0  1664.0  179.978245  176.449258
12  1792.0  1792.0  1792.0  172.914215  204.353162
13  1920.0  1920.0  1920.0  200.347822  168.585369
14  2048.0  2048.0  2048.0  226.719125  190.650180
15  2176.0  2176.0  2176.0  211.827867  211.827867
16  2304.0  2304.0  2304.0  229.691080  228.592087
17  2432.0  2432.0  2432.0  203.583068  199.251522
18  2560.0  2560.0  2560.0  224.438347  218.453323
19  2688.0  2688.0  2688.0  200.704002  198.602388
20  2816.0  2816.0  2816.0  211.719459  210.696652
21  2944.0  2944.0  2944.0  218.579083  220.513412
22  3072.0  3072.0  3072.0  208.173173  208.173173
23  3200.0  3200.0  3200.0  214.046818  219.178074
24  3328.0  3328.0  3328.0  208.067338  208.973281
25  3456.0  3456.0  3456.0  217.308808  219.080343
26  3584.0  3584.0  3584.0  216.142772  211.565625
27  3712.0  3712.0  3712.0  209.428397  213.000737
28  3840.0  3840.0  3840.0  210.250955  205.179974
29  3968.0  3968.0  3968.0  213.142249  215.971570
30  4096.0  4096.0  4096.0  221.847481  215.784121
matmul-performance-fp8:
        M       N       K      Triton
0    256.0   256.0   256.0    3.276800
1    384.0   384.0   384.0   10.053818
2    512.0   512.0   512.0   20.164923
3    640.0   640.0   640.0   34.133334
4    768.0   768.0   768.0   42.130286
5    896.0   896.0   896.0   58.538665
6   1024.0  1024.0  1024.0   61.680940
7   1152.0  1152.0  1152.0   80.702267
8   1280.0  1280.0  1280.0  102.400003
9   1408.0  1408.0  1408.0   82.602666
10  1536.0  1536.0  1536.0   99.688560
11  1664.0  1664.0  1664.0  116.868992
12  1792.0  1792.0  1792.0  135.414749
13  1920.0  1920.0  1920.0  100.905113
14  2048.0  2048.0  2048.0  114.912434
15  2176.0  2176.0  2176.0  121.226797
16  2304.0  2304.0  2304.0  134.201527
17  2432.0  2432.0  2432.0  134.423269
18  2560.0  2560.0  2560.0  146.941707
19  2688.0  2688.0  2688.0  118.171514
20  2816.0  2816.0  2816.0  129.036114
21  2944.0  2944.0  2944.0  139.988852
22  3072.0  3072.0  3072.0  144.446699
23  3200.0  3200.0  3200.0  139.433550
24  3328.0  3328.0  3328.0  131.131689
25  3456.0  3456.0  3456.0  139.725414
26  3584.0  3584.0  3584.0  149.113421
27  3712.0  3712.0  3712.0  142.506914
28  3840.0  3840.0  3840.0  138.413021
29  3968.0  3968.0  3968.0  147.194128
30  4096.0  4096.0  4096.0  156.430916
