图片速览 BitNet: 1-bit LLM

输入数据

  • 模型使用absmax 量化方法进行b比特量化,将输入量化到 [ − Q b , Q b ] ( Q b = 2 b − 1 ) \left[-Q_{b},Q_{b}\right](Q_{b}=2^{b-1}) [Qb,Qb](Qb=2b1)
    x ~ = Q u a n t ( x ) = C l i p ( x × Q b γ , − Q b + ϵ , Q b − ϵ ) , Clip ⁡ ( x , a , b ) = max ⁡ ( a , min ⁡ ( b , x ) ) , γ = ∣ ∣ x ∣ ∣ ∞ , \widetilde{x}=\mathrm{Quant}(x)=\mathrm{Clip}(x\times\frac{Q_b}{\gamma},-Q_b+\epsilon,Q_b-\epsilon),\\ \operatorname{Clip}(x,a,b)=\max(a,\min(b,x)),\quad\gamma=||x||_\infty, x =Quant(x)=Clip(x×γQb,Qb+ϵ,Qbϵ),Clip(x,a,b)=max(a,min(b,x)),γ=∣∣x,

  • 其中 ε 是一个小的浮点数,可防止在执行截断时溢出。

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitbnet_b158.py
def absmean_quantize_weights(weights):
    """
    Quantizes the weights to -1, 0, or +1 using an absmean quantization function.

    Parameters:
    - weights (Tensor): The weights of a neural network layer.

    Returns:
    - Tensor: The quantized weights.
    """
    # Calculate the average absolute value (γ) of the weights
    gamma = torch.mean(torch.abs(weights))
    
    # Scale weights by γ and round to the nearest integer among {-1, 0, +1}
    quantized_weights = torch.clamp(torch.round(weights / gamma), min=-1, max=1)
    
    return quantized_weights

权重

  • 权重 W 的二值化可以公式化为:

α = 1 n m ∑ i j W i j W ~ = S i g n ( W − α ) , Sign ⁡ ( W i j ) = { + 1 , if W i j > 0 , − 1 , if W i j ≤ 0 , \\ \alpha=\frac1{nm}\sum_{ij}W_{ij} \\ \widetilde{W}=\mathrm{Sign}(W-\alpha),\\ \left.\operatorname{Sign}(W_{ij})=\left\{\begin{array}{ll}+1,&\quad\text{if}W_{ij}>0,\\-1,&\quad\text{if}W_{ij}\leq0,\end{array}\right.\right. α=nm1ijWijW =Sign(Wα),Sign(Wij)={+1,1,ifWij>0,ifWij0,

在这里插入图片描述

矩阵乘法

  • 使用上述量化方程,矩阵乘法可以写成:

y = W ~ x ~ y=\widetilde W\widetilde{x} y=W x

  • 为了保持量化后的方差,我们在激活量化之前引入了一个 LayerNorm函数。这样,输出 y 的方差就估计为 1

y = W ~ x ~ = W ~ Quant ( LN ( x ) ) × β γ Q b y=\widetilde{W}\widetilde{x}=\widetilde{W}\text{Quant}(\text{LN}(x))\times\frac{\beta\gamma}{Q_b} y=W x =W Quant(LN(x))×Qbβγ
L N ( x ) = x − E ( x ) V a r ( x ) + ϵ , β = 1 n m ∥ W ∥ 1 \mathrm{LN}(x)=\frac{x-E(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}},\quad\beta=\frac1{nm}\|W\|_1 LN(x)=Var(x)+ϵ xE(x),β=nm1W1

在这里插入图片描述

// https://github.com/kyegomez/BitNet/blob/main/bitnet/bitlinear.py
import torch
from torch import Tensor, nn


class BitLinear(nn.Linear):
    """
    BitLinear is a custom linear layer that performs binarization of weights and quantization of activations
    in a group-wise manner.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.
        num_groups (int, optional): Number of groups to divide the weights and activations into. Default is 1.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        num_groups: int = 1,
        b: int = 8,
    ):
        super().__init__(in_features, out_features, bias)
        self.in_features = in_features
        self.out_features = out_features
        self.b = b
        self.num_groups = num_groups
        self.eps = 1e-5
        self.norm = nn.LayerNorm(in_features)

    def ste(self, x):
        """
        Applies the sign function for binarization and uses Straight-Through Estimator (STE) during backward pass.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Binarized tensor.
        """
        binarized_x = torch.sign(x)
        binarized_x = (binarized_x - x).detach() + x
        return binarized_x

    def binarize_weights_groupwise(self):
        """
        Binarizes the weights of the layer in a group-wise manner using STE.

        Returns:
            Tensor: Binarized weights tensor.
        """
        group_size = self.weight.shape[0] // self.num_groups
        binarized_weights = torch.zeros_like(self.weight)

        for g in range(self.num_groups):
            start_idx = g * group_size
            end_idx = (g + 1) * group_size
            weight_group = self.weight[start_idx:end_idx]

            alpha_g = weight_group.mean()
            binarized_weights[start_idx:end_idx] = self.ste(weight_group - alpha_g)

        return binarized_weights

    def quantize_activations_groupwise(self, x):
        """
        Quantizes the activations of the layer in a group-wise manner.

        Args:
            x (Tensor): Input tensor.
            b (int, optional): Number of bits for quantization. Default is 8.

        Returns:
            Tensor: Quantized activations tensor.
        """
        Q_b = 2 ** (self.b - 1)

        group_size = x.shape[0] // self.num_groups
        quantized_x = torch.zeros_like(x)

        for g in range(self.num_groups):
            start_idx = g * group_size
            end_idx = (g + 1) * group_size
            activation_group = x[start_idx:end_idx]

            gamma_g = activation_group.abs().max()
            quantized_x[start_idx:end_idx] = torch.clamp(
                activation_group * Q_b / (gamma_g + self.eps),
                -Q_b + self.eps,
                Q_b - self.eps,
            )

        return quantized_x
    
    def dequantize_activations_groupwise(self, x):
        """
        Dequantizes the activations of the layer in a group-wise manner.

        Args:
            x (Tensor): Quantized input tensor.
            b (int, optional): Number of bits used during the quantization. Default is 8.

        Returns:
            Tensor: Dequantized activations tensor.
        """
        Q_b = 2 ** (self.b - 1)
        dequantized_x = torch.zeros_like(x)
        for g in range(self.num_groups):
            start_idx = g * x.shape[0] // self.num_groups
            end_idx = (g + 1) * x.shape[0] // self.num_groups
            quantized_group = x[start_idx:end_idx]
            gamma_g = quantized_group.abs().max()
            dequantized_x[start_idx:end_idx] = quantized_group * gamma_g / Q_b
        return dequantized_x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the BitLinear layer.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Output tensor.
        """
        # Normalize input
        x = self.norm(x)

        # Binarize weights and quantize activations
        binarized_weights = self.binarize_weights_groupwise()

        # Perform linear transformation
        output = torch.nn.functional.linear(x, binarized_weights, self.bias)

        # Quantize activations
        output = self.quantize_activations_groupwise(output)
        
        # Dequantize activations
        output = self.dequantize_activations_groupwise(output)

        # Return output
        return output



# Example usage
bitlinear = BitLinear(10, 5, num_groups=2, b=8)
input_tensor = torch.randn(5, 10)  # Example input tensor
output = bitlinear(input_tensor)
print(output)  # Example output tensor

CG

  • 【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM

  • BitNet: Scaling 1-bit Transformers for Large Language Models

  • The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits

  • Implementation of “BitNet: Scaling 1-bit Transformers for Large Language Models” in pytorch

  • DB-LLM: Accurate Dual-Binarization for Efficient LLMs

  • 如何看待微软提出的BitNet b1.58?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/441155.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

供应链管理系统(SCM):得供应链得天下不是空话。

2023-08-26 15:51贝格前端工场 Hi,我是贝格前端工场,优化升级各类管理系统的界面和体验,是我们核心业务之一,欢迎老铁们评论点赞互动,有需求可以私信我们 一、供应链对于企业的重要性 供应链对企业经营的重要性不可…

【ViT】Vision Transformer的实现01 patch embedding

对于224*224的图像,将它输入到Transformer里面,就需要将图像展开成一系列的token, 如果逐像素视为token进行注意力的计算,难免计算量太大,因此一个更加合理的想法是将图像划分为一个个的patch 将每个patch进行embeddin…

Vue3+element-plus复杂表单分组处理

一、为什么表单要分组处理? 方便表单字段的复用:例如,你的表单有十个字段会在很多的表单都会用到,那么表单则需要进行分组进行表单复用;实现不同角色的表单权限控制:例如一个表单有60个字段,角…

3DEXPERIENCE Works八大核心优势分析

云技术正在加速普及,助力各行各业数字化转型。根据IDC 2023年12月发布的报告,2023年全球云计算市场规模达到3329亿美元,同比增长19.4%。其中,公有云市场规模达到2587亿美元,同比增长21.5%;私有云市场规模达到742亿美元…

倒计时最后1天!AutoMQ x 阿里云云原生创新论坛精彩预告

3月9日(本周六)“AutoMQ x 阿里云云原生创新论坛”就要与大家见面了,让我们一起来看看本次论坛都有哪些精彩的议题!文末附有参会交通指南和直播预约链接。 精彩剧透 01 AutoMQ:加速云原生创新,助力大数据…

Java集合面试题(day 02)

📑前言 本文主要是【JAVA】——Java集合面试题的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄每日一句&am…

支小蜜校园防欺凌系统听到声音之后会自动识别吗

在校园安全领域,特别是在预防和应对欺凌问题上,校园防欺凌系统作为新兴的技术应用,正在受到越来越多的关注和探索。那么当这样的系统听到声音之后,它是否能够自动识别并作出相应反应呢?本文将围绕这一问题展开探讨。 …

protobufjs使用教程,支持proto文件打包成typescript或javascript脚本

官方链接:https://docs.cocos.com/creator/manual/zh/scripting/modules/example.html 第一步,安装nodejs。(自行安装) 安装教程可参考 https://www.runoob.com/nodejs/nodejs-install-setup.html 第二步,创建cocos…

雷赛控制卡正负限位的信号灯不亮问题处理

雷赛控制卡正负限位的信号灯不亮问题处理 正负限位IO映射:这个映射要和轴号对映。 回零设置中的IO映射也是一样的设置。 如下图:3轴的IO映射都选3。1轴的IO映射都选1

c++的STL(2)-- vector容器

目录 1. 默认构造 代码: 相关知识点: 2. 有参构造函数 以及 使用{}初始化对象 代码: 相关知识点: 3. vector容器在尾部添加和删除元素 代码: 使用push_back()和pop_back()进行尾部元素的添加和删除 相关知识点: 代码: 使用emplace_back在尾部添…

机器学习——神经网络压缩

神经网络压缩 需要部署,设备内存和计算能力有限,需要进行模型压缩,在设备上运行的好处是低延迟,隐私性。 目录 不考虑硬件问题,只考虑通过软件算法优化。 修剪网络 参数过多或者没有用的参数,可以将其剪…

MRI基础--k空间

k空间定义 k空间是表示 MR 图像中空间频率的数字数组。 k空间物理意义 k 空间的单元通常显示在主轴 kx 和 ky 的矩形网格上。 k 空间的 kx 和 ky 轴对应于图像的水平 (x) 和垂直 (y) 轴。然而,k 轴表示 x 和 y 方向上的空间频率而不是位置。 k 空间中的各个点 (kx,ky) 与图像…

R语言 | 复数 相关函数

问题 大家好&#xff0c;我有一个问题&#xff0c;我看到一个函数如下&#xff1a; L2_distance <- function(A, B){rowA <- apply(A*A, 1, sum)matrixA <- matrix(rep(rowA, eachlength(rowA)), nrowlength(rowA), byrowT)rowB <- apply(B*B, 1, sum)matrixB &l…

【教程】Kotlin语言学习笔记(四)——方法(持续更新)

写在前面&#xff1a; 如果文章对你有帮助&#xff0c;记得点赞关注加收藏一波&#xff0c;利于以后需要的时候复习&#xff0c;多谢支持&#xff01; 【Kotlin语言学习】系列文章 第一章 《认识Kotlin》 第二章 《数据类型》 第三章 《数据容器》 第四章 《方法》 文章目录 【…

小白跟做江科大51单片机之AD/DA

1.看原理图找接口 2.看时序图编写读取数据代码 XPT2046.c代码 #include <REGX52.H> //引脚定义 sbit XPY2046_DINP3^4; sbit XPY2046_CSP3^5; sbit XPY2046_DCLKP3^6; sbit XPY2046_DOUTP3^7; unsigned int XPT2046_ReadAD(unsigned char Command) { unsigned char …

多多关键字API php java Python

多多关键字API接口广泛应用于商家进行市场分析、竞品分析、关键词优化等场景。商家可以通过分析关键词数据&#xff0c;了解用户需求&#xff0c;制定针对性的营销策略&#xff0c;提高产品的曝光率和转化率。 多多-item_seach-通过关键字搜索商品列表 公共参数 获取key和秘钥…

开发知识点-C++之win32与NT内核

win32 Windows MFC编程 常用API汇总EnumWindows()函数UpdateData()函数static与 单例 设计模式函数原型:BOOL WINAPI SetConsoleTitle(__in LPCTSTR lpConsoleTitle);HWND 是一个基本类型 表示窗口句柄FindWindow函数SendMessage函数 将指定的消息发送到一个或多个窗口PostMes…

百度地图城市点位数据下载并转换

概述 在浏览百度地图开放平台的时候&#xff0c;发现有个资源下载页面&#xff0c;里面有个城市中心点位和百度地图行政区划adcode映射表数据&#xff0c;这是一个经常使用到的数据&#xff0c;本文实现将这个数据转换为geojson&#xff0c;并借助QGIS转换为经纬度坐标或火星坐…

手写简易操作系统(一)--环境配置

本专栏是我新开设的一个学术专栏&#xff0c;旨在全面介绍手写操作系统的相关内容。其中包括实模式向保护模式的过渡、锁机制、信号量操作、内存分配、硬盘驱动、文件系统、简单shell和管道等操作系统核心知识。该专栏旨在为有意开发自己操作系统的研究人员提供指导与帮助。作为…

代码随想录刷题笔记-Day31

1. 分发饼干 455. 分发饼干https://leetcode.cn/problems/assign-cookies/ 假设你是一位很棒的家长&#xff0c;想要给你的孩子们一些小饼干。但是&#xff0c;每个孩子最多只能给一块饼干。 对每个孩子 i&#xff0c;都有一个胃口值 g[i]&#xff0c;这是能让孩子们满足胃口…