Use torch to run the BMM computation in int8.
Simulate the BMM -> softmax -> BMM computation flow.
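The code below uses symmetric int8 quantization: for each token (or each matrix), scale = max(|x|) / 127 and x_q = clamp(round(x / scale), -128, 127). Since Q ≈ Q_q * s_q and K^T ≈ K_q * s_k, the first BMM is dequantized in a single step, Q * K^T ≈ (Q_q * K_q) * (s_q * s_k); the second BMM with the softmax output and V is dequantized the same way.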
import torch
import numpy as np
torch.manual_seed(777)
def int8_quantize_per_token(x: torch.Tensor, axis: int = -1, attns=False):
    # Symmetric per-token int8 quantization: one scale per slice along `axis`.
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=axis, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale
def int8_quantize_per_tensor(x, axis=0, attns=False):
    # Symmetric int8 quantization with a single scale over the last two dims
    # (one scale per matrix); `axis` is unused.
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=-1, keepdim=True)[0]
    xmax = torch.max(xmax, dim=-2, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale
def matmul_int8(key, query, value):
    # Transpose K so BMM1 computes Q @ K^T.
    key = key.permute([0, 1, 3, 2])
    # Quantize Q per token (last dim) and K^T along the head dim (one scale
    # per key token), so the scale product broadcasts over the BMM1 output.
    query, q_s = int8_quantize_per_token(query)
    key, k_s = int8_quantize_per_token(key, -2)
    # BMM1 on the int8 values (simulated in float32), then dequantize.
    attention_scores = torch.matmul(query.type(torch.float32),
                                    key.type(torch.float32))
    scale = q_s * k_s
    attention_1 = torch.mul(attention_scores, scale)
    # Scale by sqrt(head_dim) = sqrt(32) and apply softmax.
    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))
    attention_scores = torch.softmax(attention_scores, dim=-1)
    # Quantize the attention probabilities and V (along the BMM2 reduction
    # axis), run BMM2, then dequantize.
    attention_scores_int8, attn_p_s = int8_quantize_per_token(attention_scores, attns=True)
    value, v_s = int8_quantize_per_token(value, -2)
    context = torch.matmul(attention_scores_int8.type(torch.float32),
                           value.type(torch.float32))
    scale = attn_p_s * v_s
    context = torch.mul(context, scale)
    return attention_1, context
def matmul_fp(key, query, value):
    # Float32 reference: Q @ K^T -> softmax -> @ V.
    key = key.permute([0, 1, 3, 2])
    attention_1 = torch.matmul(query.type(torch.float32),
                               key.type(torch.float32))
    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))
    attention_scores = torch.softmax(attention_scores, dim=-1)
    context = torch.matmul(attention_scores.type(torch.float32),
                           value.type(torch.float32))
    return attention_1, context
def mtx_similar1(arr1: np.ndarray, arr2: np.ndarray) -> float:
    '''
    One way to measure matrix similarity: flatten both matrices into vectors
    and divide their dot product by the product of their norms.
    Note the flattening step.
    :param arr1: matrix 1
    :param arr2: matrix 2
    :return: effectively the cosine of the angle, remapped as ret = (cos + 1) / 2
    '''
    farr1 = arr1.ravel()
    farr2 = arr2.ravel()
    len1 = len(farr1)
    len2 = len(farr2)
    if len1 > len2:
        farr1 = farr1[:len2]
    else:
        farr2 = farr2[:len1]
    numer = np.sum(farr1 * farr2)
    denom = np.sqrt(np.sum(farr1**2) * np.sum(farr2**2))
    similar = numer / denom  # this is the cosine of the angle between the vectors
    return (similar + 1) / 2  # roughly treat the cosine as linear
if __name__ == "__main__":
    # Shapes: (batch, heads, seq_len, head_dim); the query holds a single token.
    key = torch.randn((2, 6, 10, 32))
    value = torch.randn((2, 6, 10, 32))
    query = torch.randn((2, 6, 1, 32))
    i_key = key.clone().detach()
    i_value = value.clone().detach()
    i_query = query.clone().detach()
    fp_score, fp_context = matmul_fp(key, query, value)
    int8_score, int8_context = matmul_int8(i_key, i_query, i_value)
    # Cosine-based similarity between the fp32 reference and the int8 results.
    similar1 = mtx_similar1(int8_score.cpu().detach().numpy(),
                            fp_score.cpu().detach().numpy())
    similar2 = mtx_similar1(int8_context.cpu().detach().numpy(),
                            fp_context.cpu().detach().numpy())
    print(similar1, similar2)
    np.testing.assert_allclose(
        fp_score.detach().cpu().numpy(),
        int8_score.detach().cpu().numpy(),
        rtol=1e-02, atol=1e-03)
    np.testing.assert_allclose(
        fp_context.detach().cpu().numpy(),
        int8_context.detach().cpu().numpy(),
        rtol=1e-02, atol=1e-03)
Conclusions:
Per-token quantization is more accurate than per-tensor quantization.
After running BMM1 and BMM2 in fixed point (int8), the output error is relatively large.
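The per-token vs. per-tensor comparison can be reproduced by rerunning the script with the quantizer swapped out. A minimal sketch under that assumption follows; the helper matmul_int8_per_tensor is introduced here only for illustration and is not part of the original script:
def matmul_int8_per_tensor(key, query, value):
    # Same BMM -> softmax -> BMM flow, but with one scale per matrix
    # (int8_quantize_per_tensor) instead of one scale per token.
    key = key.permute([0, 1, 3, 2])
    query, q_s = int8_quantize_per_tensor(query)
    key, k_s = int8_quantize_per_tensor(key)
    attention_1 = torch.matmul(query.type(torch.float32),
                               key.type(torch.float32)) * (q_s * k_s)
    attention_scores = torch.softmax(
        attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32)), dim=-1)
    probs, p_s = int8_quantize_per_tensor(attention_scores, attns=True)
    value, v_s = int8_quantize_per_tensor(value)
    context = torch.matmul(probs.type(torch.float32),
                           value.type(torch.float32)) * (p_s * v_s)
    return attention_1, context

# Example usage (inside __main__, after the per-token run), to compare both
# schemes against the fp32 reference:
# pt_score, pt_context = matmul_int8_per_tensor(i_key, i_query, i_value)
# print(mtx_similar1(pt_score.numpy(), fp_score.numpy()),
#       mtx_similar1(pt_context.numpy(), fp_context.numpy()))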