二、模型训练与优化(4):模型优化-实操

下面我将以 MNIST 手写数字识别模型为例,从 剪枝 (Pruning)量化 (Quantization) 两个常用方法出发,提供一套可实际动手操作的模型优化流程。此示例基于 TensorFlow/Keras 环境,示范如何先训练一个基础模型,然后对其进行剪枝和量化,最后验证优化后的模型性能。


目录

  1. 整体流程概览
  2. 模型剪枝 (Pruning)
    1. 安装依赖库
    2. 修改训练脚本实现剪枝
    3. 如何运行剪枝脚本
    4. 检查与验证剪枝后模型
  3. 模型量化 (Quantization)
    1. 原理与应用场景
    2. 在脚本中添加量化步骤
    3. 运行量化脚本
    4. 验证量化后模型
  4. 常见问题与建议
  5. 总结

1. 整体流程概览

在之前博客中已经可以训练一个基础 MNIST 模型(train_mnist.py)并成功获得 mnist_model.h5 的前提下,通常会按照以下顺序进行优化:

在模型训练好后,可以在mnist_project文件夹下找到mnist_model.h5,如下:

  1. 剪枝 (Pruning):减小模型大小、去除不重要的权重,生成 pruned_mnist_model.h5
  2. (可选)量化 (Quantization):将浮点模型转化为 INT8 等低比特模型,大幅减小模型体积,并提升推理速度,生成 mnist_model_quant.tflite

在此过程中,我们需要:

  • 修改已有脚本新增脚本来执行剪枝和量化的操作。
  • 确保虚拟环境已安装必要库tensorflow-model-optimizationtensorflow-lite等)。
  • 反复验证模型的大小、推理速度、准确率,找到最适合部署需求的平衡点。

2. 模型剪枝 (Pruning)

2.1 安装依赖库

  • TensorFlow Model Optimization Toolkit:其中包含 tfmot.sparsity.keras 模块,可用于剪枝、量化感知训练等。

在激活的虚拟环境(tf_env 等)下,输入:

pip install tensorflow-model-optimization

如果已经安装过,可以跳过此步骤;若版本较旧,建议 pip install --upgrade tensorflow-model-optimization

2.2 修改训练脚本实现剪枝

这里给出的示例代码可放在一个新的脚本(如 prune_mnist.py),或者在原 train_mnist.py 中替换。示例如下:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot

def main():
    # 1. 加载 MNIST 数据集
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # 2. 数据预处理
    x_train = x_train.astype("float32") / 255.0
    x_test  = x_test.astype("float32") / 255.0
    x_train = x_train.reshape(-1, 28 * 28)
    x_test  = x_test.reshape(-1, 28 * 28)

    # 3. 定义剪枝参数
    pruning_params = {
        # PolynomialDecay 让剪枝率从 initial_sparsity 到 final_sparsity 逐渐增加
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,     # 初始剪枝率 (0%)
            final_sparsity=0.5,       # 最终剪枝率 (50%)
            begin_step=0,             # 剪枝开始 step
            end_step=np.ceil(len(x_train) / 64).astype(np.int32) * 5
            # end_step: 这里相当于 epochs * (训练集样本数 / batch_size)
        )
    }

    # 4. 构建剪枝后的模型
    #   - 先定义一个包含1~2层的网络
    #   - 使用 prune_low_magnitude 对最后一层进行剪枝封装
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tfmot.sparsity.keras.prune_low_magnitude(
            tf.keras.layers.Dense(10, activation='softmax'),
            **pruning_params
        )
    ])

    # 5. 编译模型
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # 6. 设置剪枝回调
    #    - UpdatePruningStep:在每个批次/epoch后更新剪枝进度
    #    - PruningSummaries:可选,将剪枝信息写入到指定 log_dir,配合 TensorBoard 查看
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir='logs')
    ]

    # 7. 训练模型
    #    - epochs=5 可以根据需要加大或减少
    history = model.fit(
        x_train, y_train,
        epochs=5,
        batch_size=64,
        validation_split=0.1,
        callbacks=callbacks
    )

    # 8. 模型评估
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print(f"\n测试集上的准确率: {test_acc:.4f}")

    # 9. 保存剪枝后的模型
    #    - 先使用 strip_pruning 去除剪枝包装器,得到最终“瘦身”模型
    final_model = tfmot.sparsity.keras.strip_pruning(model)
    final_model.save("pruned_mnist_model.h5")

    # 10. 可视化训练过程
    plot_history(history)

def plot_history(history):
    """可视化训练曲线"""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs_range = range(len(acc))

    plt.figure(figsize=(12, 4))

    # 绘制准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='训练准确率')
    plt.plot(epochs_range, val_acc, label='验证准确率')
    plt.legend(loc='lower right')
    plt.title('训练和验证准确率')

    # 绘制损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='训练损失')
    plt.plot(epochs_range, val_loss, label='验证损失')
    plt.legend(loc='upper right')
    plt.title('训练和验证损失')

    plt.show()

if __name__ == "__main__":
    main()
代码要点
  1. tfmot.sparsity.keras.PolynomialDecay:定义从 0% 到 50% 的剪枝率逐渐增加的策略。
  2. prune_low_magnitude(...):对目标层进行剪枝包装。可以只对某些关键层做剪枝,也可对网络所有层做封装。
  3. strip_pruning(...):剪枝训练完后,需要去掉剪枝相关的“假”节点,才能得到真正稀疏的权重以减小体积。

2.3 如何运行剪枝脚本

  1. 确保已经训练过一个基础模型(可选,如果想微调原模型);或者像示例这样直接在脚本里构建一个新的网络。
  2. 打开 Anaconda Prompt(或终端),激活虚拟环境
    conda activate tf_env
    
  3. 导航到脚本所在目录
    cd C:\Users\FCZ\Desktop\Projects\mnist_project
    
  4. 运行脚本
    python prune_mnist.py
    

训练过程结束后,会打印出测试集准确率,并在目录下生成 pruned_mnist_model.h5

2.4 检查与验证剪枝后模型

  1. 模型体积:相较原始不剪枝模型,pruned_mnist_model.h5 通常会更小,但因 HDF5 格式本身包含稀疏权重的表示方式,实际文件大小并不总是线性减少。关键是剪枝会让权重矩阵变得稀疏,后续可以配合特定框架(如 STM32Cube.AI)进行再处理。
  2. 准确率:可能略有降低,一般会在 0.97~0.98 附近。若下降过多,可调整 final_sparsity (如从 0.5 改为 0.3) 或增加微调 epochs。
  3. 后续可做量化:将剪枝后模型再进行量化,可实现进一步体积和推理速度的提升。

3. 模型量化 (Quantization)

3.1 原理与应用场景

  • 量化:把模型中的权重(和激活)从 float32 转化成 int8、float16 等低位格式,典型方式是使用 TensorFlow Lite 的离线量化。
  • 适用场景:需要在嵌入式或移动端部署,同时希望降低模型大小和加速推理。
  • 代价:可能带来少量精度损失。如果需要减小精度损失,可用量化感知训练(QAT)。

3.2 在脚本中添加量化步骤

当我们在 train_mnist.py 训练完基础模型后,在prune_mnist.py完成剪枝操作后,接下来完成量化操作,编写单独脚本 quantize_mnist.py:将 训练、剪枝、量化 三个步骤整合在一起

"""
quantize_mnist.py
-----------------
在同一个脚本中完成:
1. MNIST 基础模型训练
2. 剪枝 (Pruning)
3. 量化 (Quantization)

依赖:
  - tensorflow>=2.5
  - tensorflow-model-optimization
  - numpy, matplotlib (可选, 用于可视化)
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot

def load_mnist_data():
    """加载 MNIST 数据,并做基本预处理。"""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test  = x_test.astype("float32") / 255.0

    # 展开 28x28 -> 784
    x_train = x_train.reshape(-1, 28 * 28)
    x_test  = x_test.reshape(-1, 28 * 28)
    return (x_train, y_train), (x_test, y_test)

def create_base_model():
    """构建一个简单的全连接 MNIST 模型。"""
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

def plot_history(history, title_prefix=""):
    """可视化训练曲线"""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs_range = range(len(acc))

    plt.figure(figsize=(12, 4))

    # 准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='训练准确率')
    plt.plot(epochs_range, val_acc, label='验证准确率')
    plt.legend(loc='lower right')
    plt.title(f'{title_prefix} 训练和验证准确率')

    # 损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='训练损失')
    plt.plot(epochs_range, val_loss, label='验证损失')
    plt.legend(loc='upper right')
    plt.title(f'{title_prefix} 训练和验证损失')

    plt.show()

def main():
    # =======================================
    # 1. 数据准备
    # =======================================
    (x_train, y_train), (x_test, y_test) = load_mnist_data()

    # =======================================
    # 2. 训练基线模型
    # =======================================
    print("\n--- 步骤1: 训练基线模型 ---")
    base_model = create_base_model()
    base_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    history_base = base_model.fit(
        x_train, y_train,
        epochs=5,
        batch_size=64,
        validation_split=0.1
    )

    test_loss_base, test_acc_base = base_model.evaluate(x_test, y_test, verbose=0)
    print(f"基线模型测试集准确率: {test_acc_base:.4f}")

    # 可视化基线模型训练过程
    plot_history(history_base, title_prefix="基线模型")

    # 保存基线模型
    base_model.save("mnist_model.h5")

    # =======================================
    # 3. 剪枝 (Pruning)
    # =======================================
    print("\n--- 步骤2: 剪枝模型 ---")
    # 定义剪枝参数:从0%渐增到50%的剪枝率
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,
            final_sparsity=0.5,
            begin_step=0,
            end_step=np.ceil(len(x_train) / 64).astype(np.int32) * 5
        )
    }

    # 用之前的 base_model 权重来构造可剪枝模型
    # 也可直接对 base_model 做 prune_low_magnitude,但这里分开写更清晰
    pruned_model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tfmot.sparsity.keras.prune_low_magnitude(
            tf.keras.layers.Dense(10, activation='softmax'),
            **pruning_params
        )
    ])

    # 把 base_model 的第一层权重复制到 pruned_model 第1层
    pruned_model.layers[0].set_weights(base_model.layers[0].get_weights())

    # 编译可剪枝模型
    pruned_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # 设置回调:更新剪枝步数 + 记录日志
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir='logs')
    ]

    history_pruned = pruned_model.fit(
        x_train, y_train,
        epochs=3,          # 可以适当增加训练轮数
        batch_size=64,
        validation_split=0.1,
        callbacks=callbacks
    )

    test_loss_pruned, test_acc_pruned = pruned_model.evaluate(x_test, y_test, verbose=0)
    print(f"剪枝后模型测试集准确率: {test_acc_pruned:.4f}")
    plot_history(history_pruned, title_prefix="剪枝模型")

    # strip_pruning: 得到真正稀疏的权重
    final_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    final_pruned_model.save("pruned_mnist_model.h5")

    # =======================================
    # 4. 量化 (Quantization)
    # =======================================
    print("\n--- 步骤3: 量化剪枝后模型 (PTQ) ---")
    # 您也可以对 base_model 做量化,这里演示对 剪枝后的模型 做量化
    converter = tf.lite.TFLiteConverter.from_keras_model(final_pruned_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # 如需要 representative_dataset 来校准,可添加:
    # converter.representative_dataset = ...

    # 转换为 TFLite
    tflite_quant_model = converter.convert()

    # 保存量化后的 TFLite 文件
    with open('pruned_mnist_model_quant.tflite', 'wb') as f:
        f.write(tflite_quant_model)

    print("量化后的剪枝模型已保存: pruned_mnist_model_quant.tflite")

    # 如有需要,可使用 tflite interpreter 测试推理
    # 这里仅演示到生成 TFLite 文件即可

if __name__ == "__main__":
    main()
注意
  • 量化完成后,记得在 PC 或嵌入式设备上进行推理测试,查看最终精度。

3.3 运行量化脚本

  1. 依旧在 Anaconda Prompt 中激活环境conda activate tf_env
  2. 导航到脚本所在目录
  3. 执行: python quantize_mnist.py
  4. 观察输出:若无异常,脚本会提示 "量化后的模型已保存为 mnist_model_quant.tflite"
  • 结果文件

  •          mnist_model.h5:基线模型(未剪枝、未量化)。

       pruned_mnist_model.h5:剪枝后且 strip_pruning 的 Keras 模型。

       pruned_mnist_model_quant.tflite:剪枝后再量化的 TFLite 模型,通常体积最小,速度也更快(具体依赖硬件支持)。

  • 训练基线模型:训练 5 轮得到 mnist_model.h5
  • 剪枝模型:基于基线模型的权重进行剪枝,训练 3 轮得到 pruned_mnist_model.h5
  • 量化模型:将剪枝后的模型转换为 .tflite 格式,并保存为 pruned_mnist_model_quant.tflite

总结

接下来对优化后的模型进行验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/951289.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

免费图片批量压缩工具-支持批量修改分辨率

工作需求,需要支持修改分辨率上限的同时进行图片压缩,设计此工具。 1.支持批量文件夹、子文件 2.支持最大分辨率上限(高于设定分辨率的图片,强制修改为指定分辨率,解决大图的关键) 3.自定义压缩质量&#x…

Github上传项目

写在前面: 本次博客仅仅是个人学习记录,不具备教学作用。内容整理来自网络,太多了,所以就不放来源了。 在github页面的准备: 输入标题。 往下滑,创建 创建后会跳出下面的页面 进入home就可以看到我们刚…

并发编程 之 Java内存模型(详解)

Java 内存模型(JMM,Java Memory Model)可以说是并发编程的基础,跟众所周知的Java内存区域(堆、栈、程序计数器等)并不是一个层次的划分; JMM用来屏蔽各种硬件和操作系统的内存访问差异,以实现让Java程序在各…

[QCustomPlot] 交互示例 Interaction Example

本文是官方例子的分析: Interaction Example 推荐笔记: qcustomplot使用教程–基本绘图 推荐笔记: 4.QCustomPlot使用-坐标轴常用属性 官方例子需要用到很多槽函数, 这里先一次性列举, 自行加入到qt的.h中.下面开始从简单的开始一个个分析. void qcustomplot_main_init(void); …

WPF控件Grid的布局和C1FlexGrid的多选应用

使用 Grid.Column和Grid.Row布局,将多个C1FlexGrid布局其中,使用各种事件来达到所需效果,点击复选框可以加载数据到列表,移除列表的数据,自动取消复选框等 移除复选框的要注意!!!&am…

04、Redis深入数据结构

一、简单动态字符串SDS 无论是Redis中的key还是value,其基础数据类型都是字符串。如,Hash型value的field与value的类型,List型,Set型,ZSet型value的元素的类型等都是字符串。redis没有使用传统C中的字符串而是自定义了…

生物医学信号处理--随机信号的数字特征

前言 概率密度函数完整地表现了随机变量和随机过程的统计性质。但是信号经处理后再求其概率密度函数往往较难,而且往往也并不需要完整地了解随机变量或过程的全部统计性质只要了解其某些特定方面即可。这时就可以引用几个数值来表示该变量或过程在这几方面的特征。…

LabVIEW数据库管理系统

LabVIEW数据库管理系统(DBMS)是一种集成了数据库技术与数据采集、控制系统的解决方案。通过LabVIEW的强大图形化编程环境,结合数据库的高效数据存储与管理能力,开发人员可以实现高效的数据交互、存储、查询、更新和报告生成。LabV…

合并模型带来的更好性能

研究背景与问题提出 在人工智能领域,当需要处理多个不同任务时,有多种方式来运用模型资源。其中,合并多个微调模型是一种成本效益相对较高的做法,相较于托管多个专门针对不同任务设计的模型,能节省一定成本。然而&…

Virgo:增强慢思考推理能力的多模态大语言模型

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

本地缓存:Guava Cache

这里写目录标题 一、范例二、应用场景三、加载1、CacheLoader2、Callable3、显式插入 四、过期策略1、基于容量的过期策略2、基于时间的过期策略3、基于引用的过期策略 五、显示清除六、移除监听器六、清理什么时候发生七、刷新八、支持更新锁定能力 一、范例 LoadingCache<…

Android adb shell GPU信息

Android adb shell GPU信息 先 adb shell 进入控制台。 然后&#xff1a; dumpsys | grep GLES Android adb shell命令捕获systemtrace_android 抓trace-CSDN博客文章浏览阅读2.5k次&#xff0c;点赞2次&#xff0c;收藏8次。本文介绍了如何使用adbshell命令配合perfetto工…

ElasticSearch | Elasticsearch与Kibana页面查询语句实践

关注&#xff1a;CodingTechWork 引言 在当今大数据应用中&#xff0c;Elasticsearch&#xff08;简称 ES&#xff09;以其高效的全文检索、分布式处理能力和灵活的查询语法&#xff0c;广泛应用于各类日志分析、用户行为分析以及实时数据查询等场景。通过 ES&#xff0c;用户…

RK3588平台开发系列讲解(系统篇)Linux Kconfig的语法

文章目录 一、什么是Kconfig二、config模块三、menuconfig四、menu 和 endmenu五、choice 和 endchoice六、source七、depends on八、default九、help十、逻辑表达式沉淀、分享、成长,让自己和他人都能有所收获!😄 一、什么是Kconfig Kconfig的语法及代码结构非常简单。本博…

STM32 USB组合设备 MSC CDC

STM32 USB组合设备 MSC CDC实现 教程 教程请看大佬niu_88 手把手教你使用USB的CDCMSC复合设备&#xff08;基于stm32f407&#xff09; 大佬的教程很好&#xff0c;很详细&#xff0c;我调出来了&#xff0c;代码请见我绑定的资源 注意事项 值得注意的是&#xff1a; 1、 cu…

深入学习RabbitMQ的Direct Exchange(直连交换机)

RabbitMQ作为一种高性能的消息中间件&#xff0c;在分布式系统中扮演着重要角色。它提供了多种消息传递模式&#xff0c;其中Direct Exchange&#xff08;直连交换机&#xff09;是最基础且常用的一种。本文将深入介绍Direct Exchange的原理、应用场景、配置方法以及实践案例&a…

Node.js——path(路径操作)模块

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…

【Verdi实用技巧-Part2】

Verdi实用技巧-Part2 2 Verdi实用技巧-Part22.1 Dump波形常用的task2.1.1 Frequently Used Dump Tasks2.1.2 Demo 2.2 提取波形信息小工具--FSDB Utilities2.3 Debug in Source code view2.3.1 Find Scopes By Find Scope form 2.3.2 Go to line in Souce code View2.3.3 Use B…

web-前端小实验4

实现以上图片中的内容 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>用户注册</title&…

NLP项目实战——基于Bert模型的多情感评论分类(附数据集和源码)

在当今数字化的时代&#xff0c;分析用户评论中的情感倾向对于了解产品、服务的口碑等方面有着重要意义。而基于强大的预训练语言模型如 Bert 来进行评论情感分析&#xff0c;能够取得较好的效果。 在本次项目中&#xff0c;我们将展示如何利用 Python 语言结合transformers库&…