高性能 :DeepSeek-V3 inference 推理时反量化实现 fp8_cast_bf16

FP8 (8 bits) & FP16 (16 bits)

  • FP8 和 BF16 都是浮点数格式(floating-point formats),float通过科学计数法表示数据,float = [符号位+指数位+系数位]
FP8 (8 bits):SEEEMMMMFP16 (16 bits):SEEEEEMMMMMMMMMM
S (1 bit)S (1 bit)
EEE (3 bits)EEEEE (5 bits)
MMMM (4 bits)MMMMMMMMMM (10 bits)
  • FP8:1位符号位、3位指数位、4位尾数位。
  • FP16:1位符号位、5位指数位、10位尾数位。
特性FP8BF16
位数8 位16 位
存储需求非常低低(但高于 FP8)
精度精度非常低,仅适合低精度计算较低的精度,但比 FP8 精度高
范围较小的数值范围与 FP32 相似,具有广泛的数值范围
主要用途主要用于训练中的权重表示主要用于训练和推理,尤其适用于加速机器学习
优点极大的存储节省和计算加速适用于大规模深度学习模型,精度损失较小

fp8_cast_bf16

  • FP8到BF16转换: 主要通过weight_dequant函数将FP8权重转换为BF16格式。
import os  # 导入操作系统接口模块,用于文件和目录操作
import json  # 导入JSON模块,用于读取和写入JSON格式的数据
from argparse import ArgumentParser  # 导入ArgumentParser类,用于命令行参数解析
from glob import glob  # 导入glob模块,用于文件路径模式匹配
from tqdm import tqdm  # 导入tqdm模块,用于显示进度条

import torch  # 导入PyTorch库
from safetensors.torch import load_file, save_file  # 从safetensors库导入load_file和save_file函数

from kernel import weight_dequant  # 从kernel模块导入weight_dequant函数,用于权重解量化

def main(fp8_path, bf16_path):
    """
    将FP8权重转换为BF16并保存转换后的权重。

    该函数从指定的目录读取FP8权重,将其转换为BF16格式,
    并将转换后的权重保存到另一个指定的目录。它还更新了
    模型索引文件,反映出这些更改。

    参数:
    fp8_path (str): 存放FP8权重和模型索引文件的目录路径。
    bf16_path (str): 保存转换后的BF16权重的目录路径。

    异常:
    KeyError: 如果缺少所需的scale_inv张量,则会引发此异常。

    注意:
    - 假定FP8权重存储为safetensor文件。
    - 该函数缓存已加载的safetensor文件以优化内存使用。
    - 函数更新模型索引文件,删除对scale_inv张量的引用。
    """
    # 设置默认数据类型为bfloat16
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)  # 如果输出目录不存在,则创建它
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")  # 模型索引文件路径
    with open(model_index_file, "r") as f:
        model_index = json.load(f)  # 读取模型索引文件
    weight_map = model_index["weight_map"]  # 获取权重映射

    # 用于缓存已加载的safetensor文件
    loaded_files = {}
    fp8_weight_names = []  # 用于存储FP8权重的名称

    def get_tensor(tensor_name):
        """
        从缓存的safetensor文件中检索张量,如果没有缓存则从磁盘加载。

        参数:
            tensor_name (str): 要检索的张量名称。

        返回:
            torch.Tensor: 检索到的张量。

        异常:
            KeyError: 如果在safetensor文件中找不到指定的张量,则引发此异常。
        """
        file_name = weight_map[tensor_name]  # 获取该张量所在的文件名
        if file_name not in loaded_files:  # 如果该文件未加载
            file_path = os.path.join(fp8_path, file_name)  # 构建文件路径
            loaded_files[file_name] = load_file(file_path, device="cuda")  # 加载文件并缓存
        return loaded_files[file_name][tensor_name]  # 返回缓存的张量

    # 获取所有safetensor文件路径,并按字母排序
    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()

    # 遍历所有的safetensor文件
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)  # 获取文件名
        current_state_dict = load_file(safetensor_file, device="cuda")  # 加载当前safetensor文件
        loaded_files[file_name] = current_state_dict  # 将文件缓存起来
        
        new_state_dict = {}  # 用于存储转换后的新权重字典
        for weight_name, weight in current_state_dict.items():  # 遍历文件中的所有权重
            if weight_name.endswith("_scale_inv"):  # 如果权重是scale_inv,跳过
                continue
            elif weight.element_size() == 1:  # 如果权重是FP8(即1字节)
                scale_inv_name = f"{weight_name}_scale_inv"  # 对应的scale_inv张量名称
                try:
                    # 尝试获取对应的scale_inv张量
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)  # 将FP8权重名称记录下来
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)  # 转换为BF16
                except KeyError:
                    # 如果没有找到scale_inv张量,则跳过转换
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight  # 保留原始权重
            else:
                new_state_dict[weight_name] = weight  # 如果不是FP8,直接保留原始权重
        
        # 保存转换后的权重
        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)
        
        # 内存管理:保持仅2个最近使用的文件
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))  # 获取最老的文件
            del loaded_files[oldest_file]  # 删除最老的文件
            torch.cuda.empty_cache()  # 清理缓存

    # 更新模型索引文件
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:  # 遍历所有FP8权重
        scale_inv_name = f"{weight_name}_scale_inv"  # 对应的scale_inv名称
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)  # 从weight_map中删除scale_inv权重
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)  # 保存更新后的索引文件

if __name__ == "__main__":
    # 设置命令行参数解析
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)  # 输入FP8权重路径
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)  # 输出BF16权重路径
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)  # 调用主函数进行转换

weight_dequant

  • 引入包,建议先阅读Triton向量相加 的基础示例以理解Triton的工作方式。
from typing import Tuple
import torch
import triton
import triton.language as tl # Triton 语言(Triton Language)允许用户在 GPU 上编写高效的并行计算内核https://github.com/triton-lang/triton
from triton import Config
  • weight_dequant 函数用于将量化的权重张量(x)进行反量化处理,恢复到浮动值。以下是注释的解释:
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor of shape (M, N).
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' # 确保输入张量是连续的(即内存布局连续)
    
    
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' # 确保输入张量 x 和 s 都是二维的
    
    M, N = x.size() # 获取输入张量 x 的尺寸 M (行数) 和 N (列数)

    # 创建一个和 x 形状相同的新张量 y,用来保存反量化后的结果
    y = torch.empty_like(x, dtype=torch.get_default_dtype())

    # 定义一个 grid 函数来计算 triton 内核所需的网格大小
    # triton.cdiv 是向上取整除法,用来确保我们分配足够的线程处理每个块
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))

    # 调用 triton 内核 `weight_dequant_kernel` 进行反量化操作
    # 将 quantized weight `x` 和 scale `s` 与结果张量 `y` 一起传递给内核
    # `M`, `N`, `block_size` 作为额外的参数传递
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)

    # 返回反量化后的张量 y
    return y
  • 计算网格大小:
    • grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])): 使用 triton.cdiv 来计算块的数量。triton.cdiv 是向上取整除法,用于确定每个维度需要多少个块来处理 MN 大小的数据。meta['BLOCK_SIZE']) 是每个块处理的元素数量(默认值为 128)。

weight_dequant_kernel

  • Nvidia GPU CUDA使用grid、block、thread进行索引。
  • 实现反量化的核函数(模型可能使用的是LSQ(Learned Step Quantization)Quantization,仅有量化步长参数),通过weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)调用:
@triton.jit  # 使用 Triton 编译器将此函数编译为高效的 GPU 内核
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    
    # 获取当前线程在程序中的编号
    pid_m = tl.program_id(axis=0)  # 获取当前行维度上的线程编号,pid_m 和 pid_n 的范围由矩阵的尺寸 M 和 N,以及线程块的大小 BLOCK_SIZE 决定
    pid_n = tl.program_id(axis=1)  # 获取当前列维度上的线程编号,pid_m 的值从 0 到 ceil(M / BLOCK_SIZE) - 1
    
    # 计算矩阵列的块数
    n = tl.cdiv(N, BLOCK_SIZE)  # 使用向上取整除法计算列方向上的块数
    
    # 计算当前线程块在行和列方向上的偏移量
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # 当前块在行方向的偏移量
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)  # 当前块在列方向的偏移量
    # 将行和列的偏移量组合成一个二维的索引数组
    offs = offs_m[:, None] * N + offs_n[None, :]  # 将行和列偏移量结合,得到每个元素的全局索引,offs_m[:, None]形状会变成 (BLOCK_SIZE, 1),相加广播后变为(BLOCK_SIZE, BLOCK_SIZE)
    
    # 使用掩码保证我们不会超出矩阵的边界
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)  # 掩码,确保线程不会访问超出矩阵范围的数据
    
    # 加载量化后的权重数据(量化后的值)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)  # 从内存中加载量化后的数据,并转换为 float32 类型
    
    # 加载缩放因子
    s = tl.load(s_ptr + pid_m * n + pid_n)  # 从内存中加载对应的缩放因子,s_ptr是指向缩放因子数组的指针,pid_m * n + pid_n计算出当前线程块在缩放因子数组中的位置。
    
    # 执行去量化操作:去量化 = 量化值 * 缩放因子
    y = x * s  # 去量化的计算公式
    
    # 将去量化后的数据存储到输出缓存中
    tl.store(y_ptr + offs, y, mask=mask)  # 将去量化后的值存储到输出内存中,使用掩码确保数据存储在合法的范围内,`offs` 是索引,`mask=mask` 确保只有合法的元素被存储

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965259.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Elasticsearch】range aggregation

Elasticsearch 的Range Aggregation是一种强大的桶聚合&#xff08;Bucket Aggregation&#xff09;工具&#xff0c;用于将文档按照数值范围进行分组&#xff0c;从而实现对数据的分段分析。以下是关于 Range Aggregation 的详细说明&#xff1a; 1.Range Aggregation 的基本概…

AI测试工程师成长指南:以DeepSeek模型训练为例

目录 引言&#xff1a;AI测试工程师的使命与挑战成长日记&#xff1a;从测试小白到AI测试专家核心能力&#xff1a;AI测试工程师的必备素养知识体系&#xff1a;技术栈与技能图谱AI测试工具全景&#xff1a;以DeepSeek为核心的工具链实战训练模式&#xff1a;以DeepSeek模型迭…

Spring Boot整合MQTT

MQTT是基于代理的轻量级的消息发布订阅传输协议。 1、下载安装代理 进入mosquitto下载地址&#xff1a;Download | Eclipse Mosquitto&#xff0c;进行下载&#xff0c;以win版本为例 下载完成后&#xff0c;在本地文件夹找到下载的代理安装文件 使用管理员身份打开安装 安装…

Elasticsearch 开放推理 API 增加了 Azure AI Studio 支持

作者&#xff1a;来自 Elastic Mark Hoy Elasticsearch 开放推理 API 现已支持 Azure AI Studio。在此博客中了解如何将 Azure AI Studio 功能与 Elasticsearch 结合使用。 作为我们持续致力于为 Microsoft Azure 开发人员提供他们选择的工具的一部分&#xff0c;我们很高兴地宣…

【EdgeAI实战】(2)STM32 AI 扩展包的安装与使用

【EdgeAI实战】&#xff08;1&#xff09;STM32 边缘 AI 生态系统 【EdgeAI实战】&#xff08;2&#xff09;STM32 AI 扩展包的安装与使用 【EdgeAI实战】&#xff08;2&#xff09;STM32 AI 扩展包的安装与使用 1. STM32Cube.AI 简介1.1 STM32Cube.AI 简介1.2 X-CUBE-AI 内核引…

MySQL的 MVCC详解

MVCC是多版本并发控制&#xff0c;允许多个事务同时读取和写入数据库&#xff0c;而无需互相等待&#xff0c;从而提高数据库的并发性能。 在 MVCC 中&#xff0c;数据库为每个事务创建一个数据快照。每当数据被修改时&#xff0c;MySQL不会立即覆盖原有数据&#xff0c;而是生…

【电脑系统】电脑突然(蓝屏)卡死发出刺耳声音

文章目录 前言问题描述软件解决方案尝试硬件解决方案尝试参考文献 前言 在 更换硬盘 时遇到的问题&#xff0c;有时候只有卡死没有蓝屏 问题描述 更换硬盘后&#xff0c;电脑用一会就卡死&#xff0c;蓝屏&#xff0c;显示蓝屏代码 UNEXPECTED_STORE_EXCEPTION 软件解决方案…

SpringAI系列 - 使用LangGPT编写高质量的Prompt

目录 一、LangGPT —— 人人都可编写高质量 Prompt二、快速上手2.1 诗人 三、Role 模板3.1 Role 模板3.2 Role 模板使用步骤3.3 更多例子 四、高级用法4.1 变量4.2 命令4.3 Reminder4.4 条件语句4.5 Json or Yaml 方便程序开发 一、LangGPT —— 人人都可编写高质量 Prompt La…

为什么在springboot中使用autowired的时候它黄色警告说不建议使用字段注入

byType找到多种实现类导致报错 Autowired: 通过byType 方式进行装配, 找不到或是找到多个&#xff0c;都会抛出异常 我们在单元测试中无法进行字段注入 字段注入通常是 private 修饰的&#xff0c;Spring 容器通过反射为这些字段注入依赖。然而&#xff0c;在单元测试中&…

Ubuntu24登录PostgreSql数据库的一般方法

命令格式如 psql -U user -d db 或者 sudo psql -U user -d db 修改配置 /etc/postgresql/16/main/postgresql.conf 改成md5&#xff0c;然后重新启动pgsql sudo systemctl restart postgresql

ESP-Skainet智能语音助手,ESP32-S3物联网方案,设备高效语音交互

在科技飞速发展的今天&#xff0c;智能语音助手正逐渐渗透到我们生活的方方面面&#xff0c;而智能语音助手凭借其卓越的技术优势&#xff0c;成为了智能生活领域的一颗璀璨明星。 ESP-Skainet智能语音助手的强大之处在于其支持唤醒词引擎&#xff08;WakeNet&#xff09;、离…

数据结构与算法学习笔记----博弈论

# 数据结构与算法学习笔记----博弈论 author: 明月清了个风 first publish time: 2025.2.6 ps⭐️包含了博弈论中的两种问题Nim游戏和SG函数&#xff0c;一共四道例题&#xff0c;给出了具体公式的证明过程。 Acwing 891. Nim游戏 [原题链接](891. Nim游戏 - AcWing题库) 给…

Go 语言 | 入门 | 先导课程

快速入门 1.第一份代码 先检查自己是否有正确下载 Go&#xff0c;如果没有直接去 Go 安装 进行安装。 # 检查是否有 Go $ go version go version go1.23.4 linux/amd64然后根据 Go 的入门教程 开始进行学习。 # 初始化 Go 项目 $ mkdir example && cd example # Go…

ChatGPT提问技巧:行业热门应用提示词案例--咨询法律知识

ChatGPT除了可以协助办公&#xff0c;写作文案和生成短视频脚本外&#xff0c;和还可以做为一个法律工具&#xff0c;当用户面临一些法律知识盲点时&#xff0c;可以向ChatGPT咨询获得解答。赋予ChatGPT专家的身份&#xff0c;用户能够得到较为满意的解答。 1.咨询法律知识 举…

WPS中解除工作表密码保护(忘记密码)

1.下载vba插件 项目首页 - WPS中如何启用宏附wps.vba.exe下载说明分享:WPS中如何启用宏&#xff1a;附wps.vba.exe下载说明本文将详细介绍如何在WPS中启用宏功能&#xff0c;并提供wps.vba.exe文件的下载说明 - GitCode 并按照步骤安装 2.wps中点击搜索&#xff0c;输入开发…

【ThreeJS 01】了解 WebGL 以及 ThreeJS

文章目录 01 介绍02 什么是 WebGL&#xff0c;为什么用 ThreeJS什么是 WebGL&#xff1f;Three.js 来帮忙 01 介绍 这个课程的主讲人是 Bruno Simon&#xff0c; 这是他的作品集 他还做了一些有趣的项目&#xff1a; https://my-room-in-3d.vercel.app https://organic-sphe…

SpringBoot+Dubbo+zookeeper 急速入门案例

项目目录结构&#xff1a; 第一步&#xff1a;创建一个SpringBoot项目&#xff0c;这里选择Maven项目或者Spring Initializer都可以&#xff0c;这里创建了一个Maven项目&#xff08;SpringBoot-Dubbo&#xff09;&#xff0c;pom.xml文件如下&#xff1a; <?xml versio…

Unity Shader Graph 2D - 使用DeepSeek协助绘制一个爱心

最近十分流行使用DeepSeek AI&#xff0c;于是想尝试着能不能用DeepSeek来帮助我实现一些Shader Graph效果&#xff0c;正好之前看到了爱心图形&#xff0c;就说干脆用DeepSeek来告诉我怎么使用Shader Graph来绘制一个爱心。 问DeepSeek怎么绘制爱心 首先打开DeepSeek的网站&a…

如何正确配置您的WordPress邮件设置

在运营WordPress网站时&#xff0c;确保邮件能够顺利发送和接收是非常重要的。无论是通知、确认邮件&#xff0c;还是营销邮件&#xff0c;邮件的可靠性会直接影响用户体验。许多站长常常会遇到邮件无法送达、被标记为垃圾邮件等问题。要解决这些问题&#xff0c;使用SMTP是一个…

MySQL调优01 - 单库调优思想

单库调优 文章目录 单库调优一&#xff1a;系统中性能优化的核心思维二&#xff1a;MySQL性能优化实践1&#xff1a;连接层的优化1.1&#xff1a;连接数是越大越好吗&#xff1f;1.2&#xff1a;偶发高峰类业务的连接数配置1.3&#xff1a;分库分表情况下的连接数配置1.4&#…