使用torch模拟 BMM int8量化计算。

在这里插入图片描述
使用torch模型BMM int8计算。
模拟:BMM->softmax->BMM 计算流程

import torch
import numpy as np
torch.manual_seed(777)
def int8_quantize_per_token(x: torch.Tensor, axis: int = -1, attns=False):
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=axis, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale

def int8_quantize_per_tensor(x, axis=0, attns=False):
    if x.dtype != torch.float32:
        x = x.type(torch.float32)
    xmax = torch.abs(x)
    xmax = torch.max(xmax, dim=-1, keepdim=True)[0]
    xmax = torch.max(xmax, dim=-2, keepdim=True)[0]
    scale = xmax / 127.0
    if not attns:
        # scale = torch.clamp(scale, 1e-5, np.finfo(np.float32).max)
        pass
    else:
        # scale = torch.tensor(1 / 127.0, dtype=torch.float32)
        pass
    out = x / scale
    out = torch.round(out)
    out = torch.clamp(out, -128, 127)
    quantized_out = out.type(torch.int8)
    return quantized_out, scale

def matmul_int8(key, query, value):
    key = key.permute([0, 1, 3, 2])
    query, q_s = int8_quantize_per_token(query)
    key, k_s = int8_quantize_per_token(key, -2)
    attention_scores = torch.matmul(query.type(torch.float32),
                                    key.type(torch.float32))

    scale = q_s * k_s
    attention_1 = torch.mul(attention_scores, scale)

    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))
    attention_scores = torch.softmax(attention_scores, dim=-1)

    attention_scores_int8, attn_p_s = int8_quantize_per_token(attention_scores, attns=True)
    value, v_s = int8_quantize_per_token(value, -2)

    context = torch.matmul(attention_scores_int8.type(torch.float32),
                           value.type(torch.float32))

    scale = attn_p_s * v_s
    context = torch.mul(context, scale)

    return attention_1, context


def matmul_fp(key, query, value):
    key = key.permute([0, 1, 3, 2])
    attention_1 = torch.matmul(query.type(torch.float32),
                                    key.type(torch.float32))

    attention_scores = attention_1 / torch.sqrt(torch.tensor(32, dtype=torch.float32))
    attention_scores = torch.softmax(attention_scores, dim=-1)

    context = torch.matmul(attention_scores.type(torch.float32),
                           value.type(torch.float32))

    return attention_1, context


def mtx_similar1(arr1:np.ndarray, arr2:np.ndarray) ->float:
    '''
    计算矩阵相似度的一种方法。将矩阵展平成向量,计算向量的乘积除以模长。
    注意有展平操作。
    :param arr1:矩阵1
    :param arr2:矩阵2
    :return:实际是夹角的余弦值,ret = (cos+1)/2
    '''
    farr1 = arr1.ravel()
    farr2 = arr2.ravel()
    len1 = len(farr1)
    len2 = len(farr2)
    if len1 > len2:
        farr1 = farr1[:len2]
    else:
        farr2 = farr2[:len1]

    numer = np.sum(farr1 * farr2)
    denom = np.sqrt(np.sum(farr1**2) * np.sum(farr2**2))
    similar = numer / denom # 这实际是夹角的余弦值
    return  (similar+1) / 2     # 姑且把余弦函数当线性


if __name__ == "__main__":
    key = torch.randn((2, 6, 10, 32))
    value = torch.randn((2, 6, 10, 32))
    query = torch.randn((2, 6, 1, 32))
    i_key = key.clone().detach()
    i_value = value.clone().detach()
    i_query = query.clone().detach()

    fp_score, fp_context = matmul_fp(key, query, value)

    int8_score, int8_context = matmul_int8(i_key, i_query, i_value)

    similar1 = mtx_similar1(int8_score.cpu().detach().numpy(),
                           fp_score.cpu().detach().numpy())

    similar2 = mtx_similar1(int8_context.cpu().detach().numpy(),
                           fp_context.cpu().detach().numpy())

    print(similar1, similar2)

    np.testing.assert_allclose(
             fp_score.detach().cpu().numpy(),
             int8_score.detach().cpu().numpy(),
             rtol=1e-02, atol=1e-03)

    np.testing.assert_allclose(
             fp_context.detach().cpu().numpy(),
             int8_context.detach().cpu().numpy(),
             rtol=1e-02, atol=1e-03)

结论:
Per-token 精度优于per-tensor
BMM1 和 BMM2定点计算之后,输出误差较大

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/934643.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode 每日一题 219.存在重复元素 II

目录 问题描述 输入输出格式 示例 算法分析 过题图片 代码实现 复杂度分析 题目链接 总结 问题描述 给定一个整数数组nums和一个整数k&#xff0c;我们需要判断数组中是否存在两个不同的索引i和j&#xff0c;使得nums[i] nums[j]且|i - j| < k。如果存在这样的i和…

ragflow连不上ollama的解决方案

由于前期wsl默认装在C盘&#xff0c;后期部署好RagFlow后C盘爆红&#xff0c;在连接ollama的时候一直在转圈圈&#xff0c;问其他人没有遇到这种情况&#xff0c;猜测是因为内存不足无法加载模型导致&#xff0c;今天重新在E盘安装wsl 使用wsl装Ubuntu Win11 wsl-安装教程 如…

PR的选择与移动

选择工具 可以选择序列上的剪辑&#xff0c;如果需要多选可以按住shift键选中多个剪辑 CtrlA&#xff1a;可以进行全选 编组 选中多个剪辑后“右键-编组“可以将所选的剪辑连接在一起。这时单击任意剪辑都可以选中全部 向前选择轨道工具与向后选择轨道工具 向前选择轨道工具…

使用C#基于ADO.NET编写MySQL的程序

MySQL 是一个领先的开源数据库管理系统。它是一个多用户、多线程的数据库管理系统。MySQL 在网络上特别流行。MySQL 数据库可在大多数重要的操作系统平台上使用。它可在 BSD Unix、Linux、Windows 或 Mac OS 上运行。MySQL 有两个版本&#xff1a;MySQL 服务器系统和 MySQL 嵌入…

Python3中赋值运算符说明二

一. 简介 前面文章简单学习了 Python3中一些赋值运算符&#xff0c;文章如下&#xff1a; Python3中赋值运算符上篇-CSDN博客 本文继续学习 Python3中另外一些赋值运算符。 二. Python3 中赋值运算符 1. Python3 中赋值运算符 前一篇文章简单学习了 Python3 中的一些赋值…

如何在 Ubuntu 22.04 上安装和使用 Apache Kafka

简介 Apache Kafka是一个高性能、低延迟的分布式流处理平台&#xff0c;广泛用于构建实时数据管道和流式应用。本文将指导你如何在Ubuntu 22.04系统上快速部署Apache Kafka&#xff0c;让你体验到Kafka在处理大规模实时数据流方面的强大能力。通过本教程&#xff0c;你将学会如…

群控系统服务端开发模式-应用开发-自动退出发送邮件

一、修改Redis配置文件 将redis.conf里面的notify-keyspace-events参数对应的值改为Ex&#xff0c;具体代码如下&#xff1a; notify-keyspace-events Ex 二、创建控制台命令 在根目录下config文件夹下找到console.php文件修改&#xff0c;具体代码如下&#xff1a; <?p…

前端篇 -- jQuery详细教程

jQuery教程 jQuery官网1.1 jQuery的基本介绍1.2 jQuery 基本开发步骤1.3 jQuery对象和DOM对象 1.3.1 jQuery对象的基本介绍1.3.2 DOM对象转 jQuery对象1.3.3 jQuery对象转DOM对象 1.4 jQuery选择器 1.4.1 jQuery 基本选择器介绍1.4.2 基本选择器1.4.3 层次选择器1.4.4 基础过滤…

【数模学习笔记】模糊综合评价

声明&#xff1a;以下笔记中的图片均来自“数学建模学习交流”清风老师的课程ppt&#xff0c;仅用作学习交流使用 模糊综合评价 文章目录 模糊综合评价模糊数学经典集合和模糊集合的基本概念经典集合和特征函数模糊集合和隶属函数模糊集合的分类 隶属函数的确定方法方法一 模糊…

STM32F103单片机使用STM32CubeMX新建IAR工程步骤

打开STM32CubeMX软件&#xff0c;选择File 选择新建工程 在打开的窗口输入单片机型号 在右下角选择单片机型号&#xff0c;然后点右上角 start project&#xff0c;开始新建工程。 接下来设置调试接口&#xff0c;在左边System Core中选择 SYS&#xff0c;然后在右右边debu…

相机(Camera)硬件组成详解

简介&#xff1a;个人学习分享&#xff0c;如有错误&#xff0c;欢迎批评指正。 写在前面&#xff1a;可以去B站观看一些相机原理的视频来配合学习&#xff0c;这里推荐&#xff1a;推荐1&#xff0c;推荐2&#xff0c;推荐3 相机&#xff08;Camera&#xff09;是一种复杂的光…

String【Redis对象篇】

&#x1f3c6; 作者简介&#xff1a;席万里 ⚡ 个人网站&#xff1a; 文章目录 String1.String是什么&#xff1f;2.String怎么用&#xff1f;3.常用操作4.底层实现&#xff1f;5.总结&#xff08;重点&#xff09; String 1.String是什么&#xff1f; String就是字符串&…

乘上 SpringBoot 东风,广场舞团掀起律动热潮

2 系统开发环境 2.1 Java技术 Java是由Sun公司推出的一门跨平台的面向对象的程序设计语言。因为Java 技术具有卓越的通用性、高效性、健壮的安全性和平台移植性的特点&#xff0c;而且Java是开源的&#xff0c;拥有全世界最大的开发者专业社群&#xff0c;所以Java的发展迅速。…

组件开发的环境准备

目录​​​​​​​ node.js的安装 npm镜像源的修改 pnpm包管理器的安装&#xff08;全局安装&#xff09; 基于pnpm创建脚手架项目 node.js的安装 Node.js 是一个开源的、跨平台的 JavaScript 运行环境&#xff0c;能够在服务器端执行 JavaScript 代码。 a.下载与安装 …

【OpenCV】Canny边缘检测

理论 Canny 边缘检测是一种流行的边缘检测算法。它是由 John F. Canny 在 1986 年提出。 这是一个多阶段算法&#xff0c;我们将介绍算法的每一个步骤。 降噪 由于边缘检测易受图像中的噪声影响&#xff0c;因此第一步是使用 5x5 高斯滤波器去除图像中的噪声。我们在前面的章…

gitee常见命令

目录 1.本地分支重命名 2.更新远程仓库分支 3.为当前分支设置远程跟踪分支 4.撤销已经push远程的代码 5.idea->gitee的‘还原提交’ 需要和本地当前的代码解决冲突 解决冲突 本地工作区的差异代码显示 本地commit和push远程 6.idea->gitee的‘将当前分支重置到此…

Ultra-Fast-Lane-Detection复现、部署及训练

Ultra-Fast-Lane-Detection复现、训练及部署 一、复现二、训练三、部署 一、复现 Github下载源码&#xff1a;https://github.com/cfzd/Ultra-Fast-Lane-Detection &#xff08;1&#xff09;将GPU运算改为CPU运算&#xff1a;.cuda() -> .to(‘cpu’) test.py中33行&…

【Java计算机毕业设计】基于SSM+VUE宠物领养管理系统【源代码+数据库+LW文档+开题报告+答辩稿+部署教程+代码讲解】

源代码数据库LW文档&#xff08;1万字以上&#xff09;开题报告答辩稿 部署教程代码讲解代码时间修改教程 一、开发工具、运行环境、开发技术 开发工具 1、操作系统&#xff1a;Window操作系统 2、开发工具&#xff1a;IntelliJ IDEA或者Eclipse 3、数据库存储&#xff1a…

工业异常检测-CVPR2024-新的3D异常数据合成办法和自监督网络IMRNet

论文&#xff1a;https://arxiv.org/pdf/2311.14897v3.pdf 项目&#xff1a;https://github.com/chopper-233/anomaly-shapenet 这篇论文主要关注的是3D异常检测和定位&#xff0c;这是一个在工业质量检查中至关重要的任务。作者们提出了一种新的方法来合成3D异常数据&#x…

Linux-ubuntu环境配置

一&#xff0c;安装VWware&#xff0c;里面导入镜像文件 这些都是文件夹里面有的&#xff0c;然后对着正点原子视频安装就行&#xff0c;虚拟机的破解码&#xff0c;去百度搜一个能用就行&#xff0c;中间遇见俩问题。①乌班图里面不能上网&#xff0c;②插入U盘后&#xff0c;…