【模型压缩】原理及实例

在移动智能终端品类越发多样的时代,为了让模型可以顺利部署在算力和存储空间都受限的移动终端,对模型进行压缩尤为重要。模型压缩(model compression)可以降低神经网络参数量,减少延迟时间,从而实现提高神经网络推理速度、节省存储空间等目的。

一.量化

量化是指将模型权重参数用更少的比特数存储,以此来减少模型的存储空间和算力消耗。

1.基本原理

(1) 量化感知训练

Quantization-aware Training,QAT在训练过程中模拟量化过程,数据虽然表示为float32,但实际的值的间隔却会受到量化参数的设置。

QAT的具体流程如下:

1)初始化:设置权重和激活值范的范围q_{min}q_{max}的初始值;

2)构建模拟量化网络:在需要量化的权重和激活值后插入伪量化算子;

3)量化训练:重复执行以下步骤直至网络收敛(计算量化网络层的权重和激活值的范围q_{min}q_{max},并根据该范围将量化损失带入到前向推理和后向参数更新的过程中);

4)导出量化网络:获取q_{min}q_{max},并计算量化参数,将量化参数s和z代入到量化公式中,转换网络中的权重为量化整数值;删除伪量化算子,在量化网络层前后分别插入量化和反量化算子。

(2) 后训练动态量化

Post training dynamic quantization是在浮点模型训练收敛之后进行量化操作,weight被提前量化,activation在前向推理过程中被动态量化(即每次都要根据实际运算的浮点数据范围每一层计算1次scale和zero_point,然后进行量化)。

在量化激活值时会以校准数据集为输入,执行推理流程然后统计每层激活值的数据分布并得到相应的量化参数,具体操作流程如下:

1)使用直方图统计的方式得到原始float32数据的统计分布P_{f}

2)在给定的搜索空间中选取若干个q_{min}q_{max}分别对激活值进行量化,得到量化后的数据Q_{q}

3)使用直方图统计得到Q_{q}的统计分布;

4)计算每个Q_{q}P_{f}的统计分布差异,并找到差异性最低的1个对应的q_{min}q_{max}来计算相应的量化参数;常用的用于度量分布差异的指标包括KL散度、对称KL散度和JS散度。

(3) 后训练静态量化

activation会基于之前校准过程中记录下的固定的scale和zero_point进行量化,整个过程不存在量化参数(scale,zero_point)的再计算。

2.代码实例

(1) 加载数据

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert

# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

(2) 构建量化网络

class QuantizedCNN(nn.Module):
    def __init__(self):
        super(QuantizedCNN, self).__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')

(3) 量化训练并保存模型

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    # 切换到评估模式进行测试
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))

    # 在最后1个epoch后完成量化
    if epoch == num_epochs - 1:
        model_quantized = convert(model.eval(), inplace=True)
        print("Model quantization completed.")
        # 保存量化模型
        torch.save(model_quantized.state_dict(), 'quantized_model.pth')

(4) 模型测试

def test_quantized_model(model, dataloader, device='cpu'):
    model = convert(model.eval(), inplace=True)
    model.to(device) 
    correct = 0
    total = 0

    with torch.no_grad(): 
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)  
            outputs = model(data)  
            _, predicted = torch.max(outputs.data, 1)  
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')

# 测试模型
quantized_model=QuantizedCNN()
quantized_model.load_state_dict(torch.load('quantized_model.pth'))
test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

二.剪枝

剪枝是指去除模型参数中冗余或不重要的部分,可以高效地生成规模更小、内存利用率更高、能耗更低、推断速度更快的模型。

1.基本原理

根据剪枝流程的位置,可以将剪枝操作分为2种:训练时剪枝和后剪枝。

(1) 训练时剪枝

和训练时使用dropout操作较为类似,训练时剪枝会根据当前模型的结果,删除不重要的结构,固化模型再进行训练,以后续的训练来弥补部分结构剪枝带来的不利影响。

(2) 后剪枝

在模型训练完成后,根据模型权重参数和剪枝测试选取需要剪枝的部分。

2.代码实例

(1) 加载预训练模型

import torch
import torchvision.models as models

# 加载预训练的ResNet18模型
model = models.resnet18(pretrained=True)

(2) 定义剪枝算法

from torch.nn.utils.prune import global_unstructured

# 定义剪枝比例
pruning_rate = 0.5

# 对全连接层进行剪枝
def prune_model(model, pruning_rate):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            global_unstructured(module, pruning_dim=0, amount=pruning_rate)

(3)执行剪枝操作

prune_model(model, pruning_rate)

# 查看剪枝后的模型结构
print(model)

(4) 重新训练和微调

剪枝后的模型需要重新进行训练和微调,以保证模型的准确性和性能。

# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

(5) 性能测试

三.蒸馏

蒸馏是指将知识从大模型(教师模型)向小模型(学生模型)传输的过程,可以用于模型压缩和训练加速。核心组件包括:知识(knowledge)、蒸馏算法(distillation algorithm)、教师学生架构(teacher-student architecture)。

1.基本原理

蒸馏的知识的形式可以是:激活、神经元、中间层特征、教师网络参数等。可将其归类为以下3种类型。

(1) Feature-Based Knowledge

基于特征的知识蒸馏引入中间层表征,教师网络的中间层作为学生网络对应层的提示(Hints层),从而提升学生网络模型的性能。核心是期望学生能够直接模仿教师网络的特征激活值。

(2) Relation-Based Knowledege

基于关系的知识蒸馏可以分为不同层之间的关系建模和不同样本之间的关系建模2种。

•不同层之间的关系建模

通常可以建模为:

其中,f _{t}f_{s}表示学生网络内成对的特征图,\Psi _{t}\Psi _{s}是相似度函数,L_{​{R^{1}}}代表教师网络与学生网络的关联函数。

•不同样本之间的关系建模

建模如下:

其中,F _{t}F _{s}分别是teacher和student模型的特征表示;\left ( t_{i}, t_{j}\right )\in F _{t}\left ( s_{i}, s_{j}\right )\in F _{s}

基于关系的知识蒸馏的具体算法如下表所示。

(3) Response-Based Knowleddge

基于响应的知识蒸馏里响应一般指的是神经元的响应,即教师模型的最后1层逻辑输出。核心想法是让学生模型模仿教师网络的输出。

响应知识的loss:

Hinton提出的KD是将teacher的logits层作为soft label:

T是用于控制soft target重要程度的超参数。

整体蒸馏loss可以写作:

2.代码实例

(1) 加载数据

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

(2) 构建teacher 、student模型结构

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

(3) 训练模型

# 1.Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=3)
teacher.evaluate(x_test, y_test)

# 2.Train student as  usual
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate student on data
student_scratch.fit(x_train, y_train, epochs=3)
student_scratch.evaluate(x_test, y_test)

(4) 构建蒸馏模型

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data
        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)

            # Compute scaled distillation loss
            distillation_loss = (
                self.distillation_loss_fn(
                    tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                    tf.nn.softmax(student_predictions / self.temperature, axis=1),
                )
                * self.temperature**2
            )

            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions
        y_prediction = self.student(x, training=False)
        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)
        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results

(5)蒸馏

# Train student as doen usually
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate student trained from scratch.
student_scratch.fit(x_train, y_train, epochs=1)
student_scratch.evaluate(x_test, y_test)

四.参考

(1) Knowledge Distillation: A Survey

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/939784.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android Stduio 2024版本设置前进和后退按钮显示在主界面

Android Studio 2024(Ladybug)安装后发现前进和后退按钮不显示在主界面的工具栏,且以前在View中设置的办法无效: Android Studio 2024(Ladybug)的设置方式: File->Settings->Appearance&…

MySQL数据库——门诊管理系统数据库数据表

门诊系统数据库his 使用图形化工具或SQL语句在简明门诊管理系统数据库his中创建数据表,数据表结构见表2-3-9~表2-3-15所示。 表2-3-9 department(科室信息表) 字段名称 数据类型 长度 是否为空 说明 dep_ID int 否 科室…

Ubuntu上如何部署Nginx?

环境: Unbuntu 22.04 问题描述: Ubuntu上如何部署Nginx? 解决方案: 在Ubuntu上部署Nginx是一个相对简单的过程,以下是详细的步骤指南。我们将涵盖安装Nginx、启动服务、配置防火墙以及验证安装是否成功。 1. 更新…

【从零开始入门unity游戏开发之——C#篇08】逻辑运算符、位运算符

文章目录 一、逻辑运算符1、**&&(逻辑与)**语法:示例: 2、**||(逻辑或)**语法:示例: 3、**!(逻辑非)**语法:示例: 4、**^&…

【Android开发】安装Android Studio(2023.1.1)

下载安装包 Android Studio2023.1.1百度云盘下载,提取码:6666https://pan.baidu.com/s/1vNJezi7aDOP0poPADcBZZg?pwd6666 安装Android Studio 2023.1.1 双击下载好的安装包 弹出界面点击下一步 继续点击【Next】 更改安装路径后继续点击【Next】 点…

.net winform 实现CSS3.0 泼墨画效果

效果图 代码 private unsafe void BlendImages1(Bitmap img1, Bitmap img2) {// 确定两个图像的重叠区域Rectangle rect new Rectangle(0, 0,Math.Min(img1.Width, img2.Width),Math.Min(img1.Height, img2.Height));// 创建输出图像,尺寸为重叠区域大小Bitmap b…

Linux下部署MySQL8.0集群 - 主从复制(一主两从)

目录 一、部署前准备 1、查看系统信息 # 查看系统版本 cat /etc/red* # 查看系统位数 getconf LONG_BIT[rootlocalhost ~]# cat /etc/red* CentOS Linux release 7.5.1804 (Core) [rootlocalhost ~]# getconf LONG_BIT 642、下载对应安装包 进入MySQL官网:https:…

编辑, 抽成组件

问题 错误思路: 1 dept不能修改, 用watch监听一下:赋值给新的变量进行修改, 问题: currentDept 发生改变, depth也发生了改变,因为是浅拷贝, 用了json.pase(json.stringify(value…

2009 ~ 2019 年 408【计算机网络】大题解析

2009 年 路由算法(9’) 讲解视频推荐:【BOK408真题讲解-2009年(催更就退网版)】 某网络拓扑如下图所示,路由器 R1 通过接口 E1 、E2 分别连接局域网 1 、局域网 2 ,通过接口 L0 连接路由器 R2 &…

MySQL追梦旅途之慢查询分析建议

一、找到慢查询 查询是否开启慢查询记录 show variables like "%slow%";log_slow_admin_statements: 决定是否将慢管理语句(如 ALTER TABLE 等)记录到慢查询日志中。 log_slow_extra : MySQL 和 MariaDB 中的一个系…

进阶版 -- 某恋爱话术 app 的爬虫经历与思考(含脚本)

背景 承接前文,由于上一个app 爬出来的数据只有 1w 多条,感觉不是很过瘾 所以这次又找到了一个非破解版 app,数据量大概有 40w,安全等级直线上升 声明 本次爬虫是学习实践行为,获取到的数据均已在 24 小时内全部删…

深入理解 Linux 内核启动流程

目录 一、BIOS 与 Bootloader 1.BIOS(Basic Input/Output System) 2.Bootloader(引导加载程序) 二、内核初始化 1.解压内核映像 2.初始化硬件设备 3.建立内存管理系统 4.启动第一个进程(init) 三、…

Android笔记【19】

具体示例 run: val result someObject.run {// 这里可以使用 thisthis.someMethod() }let: val result someObject?.let {// 这里使用 itit.someMethod() }with: val result with(someObject) {// 这里使用 thissomeMethod() }apply: val obj SomeClass().apply {// 这里使…

【Qt】qt安装

在工作一年之后,还是想做一个Qt的教程,遥想研一刚刚接触Qt,从0到1学习,没有什么参考书籍,网上的资料也不多,幸好Qt官方文档写得好,加上自己肯研究,才堪堪入门。 现在我想自己写一个…

Word使用分隔符实现页面部分分栏

文章目录 Word使用分隔符实现页面部分分栏分隔符使用页面设置 Word使用分隔符实现页面部分分栏 分隔符使用 word中的分隔符: 前面不分栏,后面分栏(或前面分栏,后面不分栏),只需要在分隔位置处插入分隔符:“连续”即…

搭建Tomcat(四)---Servlet容器

目录 引入 Servlet容器 一、优化MyTomcat ①先将MyTomcat的main函数搬过来: ②将getClass()函数搬过来 ③创建容器 ④连接ServletConfigMapping和MyTomcat 连接: ⑤完整的ServletConfigMapping和MyTomcat方法: a.ServletConfigMappin…

谁说C比C++快?

看到这个问题,我我得说:这事儿没有那么简单。 1. 先把最大的误区打破 "C永远比C快" —— 某位1990年代的程序员 这种说法就像"自行车永远比汽车省油"一样荒谬。我们来看个例子: // C风格 char* str (char*)malloc(100…

html <a>设置发送邮件链接、打电话链接 <a href=“mailto:></a> <a href=“tel:></a>

1.代码 <ul><li>电话&#xff1a;<a href"tel:18888888888">188-8888-8888</a></li><li>邮箱&#xff1a;<a href"mailto:10000qq.com">10000qq.com</a></li><li>邮箱&#xff1a;<a hre…

Nginx三种安装方式

Nginx安装 可以登录 Nginx 的官方网站&#xff1a;https://www.nginx.com/ 找到安装方式。 查看如何安装开源的版本&#xff1a;https://docs.nginx.com/nginx/admin-guide/installing-nginx/installing-nginx-open-source/ 通过官方的说明&#xff0c;也可以知道安装&#…

Android 10 Launcher3 删除谷歌搜索

命令行获取页面 手机处于launcher首页 adb shell dumpsys window | findstr mCurrentFocus 输出 mCurrentFocusWindow{9afb34d u0 com.android.launcher3/com.android.launcher3.Launcher} 找到源码路径 packages/apps/Launcher3/ Android10源码 搜索控件 grep -r -n Apps…