一 基本语法
1 torch中tensor的声明
x = torch.tensor([[1,2, 1, 1, 1, 1, 1, 1],[2,2,2,2,2,2,2,2]],device='cuda')
声明的时候有的时候需要指出数据的类型,不然在kernel中数据类型无法匹配
x = torch.tensor([1,2,1,1,1,1,1,1],dtype = torch.int32,device='cuda')
2 idx
idx表示的是数据在sram中的索引,如果idx为(),则表示为只有一个数据
二 triton中函数
1 sum
output = tl.sum(x,axis = 0)
如果输入是torch中的声明的话,则输出为
这就是一个reduce的过程,将x轴的数据全部相加,得到一个数字
2 乘法
2.1 矩阵乘
dot
2.2 卷积乘(叉乘)
*
2.3 具体说明
当