类BERT模型蒸馏实战

机器学习模型已经变得越来越大,以至于训练模型可能会给那些没有空闲集群的人带来痛苦。 此外,即使使用训练好的模型,当你的硬件与模型对其运行的期望不符时,推理的时间和内存成本也会飙升。 因此,为了缓解这个问题,我们并没有放弃类似 BERT 模型的深层知识,而是开发了一种称为蒸馏(distillation)的技术,将网络缩小到合理的大小,同时最大限度地减少性能损失。

如果你已经阅读了本系列的第一篇文章,那么这并不是什么新闻。 在其中,我们讨论了 DistilBERT  如何引入一种简单而有效的蒸馏技术,可以轻松应用于任何类似 BERT 的模型,但我们避开了任何具体的实现。 现在,我们将详细介绍如何将想法转化为 .py 文件。

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 

1、学生模型的初始化

由于我们想要从现有模型初始化一个新模型,因此需要访问旧模型(即教师)的权重。 我们假设预先存在的模型是在 PyTorch 上实现的 Hugging Face 模型, 因此,要获得权重,首先必须知道如何访问它们。 我们将使用 RoBERTa large 作为我们的教师模型。

1.1 Hugging Face的模型结构

我们可以尝试的第一件事是打印模型结构,这应该让我们深入了解它是如何制作的。 当然,我们总是可以深入研究 Hugging Face 文档 ,但这并不有趣。

from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")

print(roberta)

运行此代码后,我们得到:

简单打印 RoBERTA 的第一印象

模型的结构开始出现,但我们可以让它变得更漂亮。 在 Hugging Face 模型中,我们可以使用 .children() 生成器访问模块的子组件。 因此,如果我们想要遍历整个模型,我们需要在其上调用 .children() ,并在每个产生的子级上继续调用 .children() ,等等......这描述了一个递归函数,代码如下:

from typing import Any
from transformers import AutoModelForMaskedLM

roberta = AutoModelForMaskedLM.from_pretrained("roberta-large")

def visualize_children(
    object : Any,
    level : int = 0,
) -> None:
    """
    Prints the children of (object) and their children too, if there are any.
    Uses the current depth (level) to print things in a ordonnate manner.
    """
    print(f"{'   ' * level}{level}- {type(object).__name__}")
    try:
        for child in object.children():
            visualize_children(child, level + 1)
    except:
        pass

visualize_children(roberta)

输出结果如下:

RoBERTa 的递归预览

通过展开这棵树,看起来 RoBERTa 模型的结构与其他类似 BERT 的模型一样,如下所示:

类 BERT 模型的架构

1.2 复制教师模型的权重

我们知道,要以 DistilBERT 的方式初始化类似 BERT 的模型,我们只需要复制除 Roberta 层最深层之外的所有内容,我们省略了其中的一半。

首先,我们需要创建学生模型,其架构与教师模型相同,但隐藏层数量只有一半。
为此,我们只需要使用教师模型的配置,它是一个类似字典的对象,描述了 Hugging Face 模型的架构。 当查看 roberta.config 属性时,我们可以看到以下内容:

RoBERTa 配置

我们在这里感兴趣的是 num-hidden-layers 属性。 让我们编写一个函数来复制此配置,通过将其除以 2 来更改该属性,并使用新配置创建一个新模型:

from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfig

def distill_roberta(
    teacher_model : RobertaPreTrainedModel,
) -> RobertaPreTrainedModel:
    """
    Distilates a RoBERTa (teacher_model) like would DistilBERT for a BERT model.
    The student model has the same configuration, except for the number of hidden layers, which is // by 2.
    The student layers are initilized by copying one out of two layers of the teacher, starting with layer 0.
    The head of the teacher is also copied.
    """
    # Get teacher configuration as a dictionnary
    configuration = teacher_model.config.to_dict()
    # Half the number of hidden layer
    configuration['num_hidden_layers'] //= 2
    # Convert the dictionnary to the student configuration
    configuration = RobertaConfig.from_dict(configuration)
    # Create uninitialized student model
    student_model = type(teacher_model)(configuration)
    # Initialize the student's weights
    distill_roberta_weights(teacher=teacher_model, student=student_model)
    # Return the student model
    return student_model

当然,这个函数引入了一个缺失的部分 distill_roberta_weights ,该函数会将教师模型权重的一半置于学生层中,但我们仍然需要对其进行编码。 由于递归对于探索教师模型效果很好,因此我们可以使用相同的想法来探索和复制其中的部分内容。 我们将同时浏览教师模型和学生模型,同时将部分内容从一个模型复制到另一个模型。 唯一的技巧是要小心隐藏层部分并只复制一半。实现代码如下:

from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModel
from torch.nn import Module

def distill_roberta_weights(
    teacher : Module,
    student : Module,
) -> None:
    """
    Recursively copies the weights of the (teacher) to the (student).
    This function is meant to be first called on a RobertaFor... model, but is then called on every children of that model recursively.
    The only part that's not fully copied is the encoder, of which only half is copied.
    """
    # If the part is an entire RoBERTa model or a RobertaFor..., unpack and iterate
    if isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith('RobertaFor'):
        for teacher_part, student_part in zip(teacher.children(), student.children()):
            distill_roberta_weights(teacher_part, student_part)
    # Else if the part is an encoder, copy one out of every layer
    elif isinstance(teacher, RobertaEncoder):
            teacher_encoding_layers = [layer for layer in next(teacher.children())]
            student_encoding_layers = [layer for layer in next(student.children())]
            for i in range(len(student_encoding_layers)):
                student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2*i].state_dict())
    # Else the part is a head or something else, copy the state_dict
    else:
        student.load_state_dict(teacher.state_dict())

该函数通过递归和类型检查,确保学生模型与教师模型相同,对于 Roberta 层来说是安全的。 可以注意到,如果我们想在初始化教师模型时更改复制哪些层,则只有编码器部分中的 for 循环需要更改。

现在我们有了学生模型,我们需要训练它。 除了要使用的损失函数之外,这部分相对简单。

2、自定义损失函数

作为对 DistilBERT 训练过程的回顾,我们可以看下图:

DistilBERT 的蒸馏过程

我们将把注意力转向那个写着 LOSS 的红色大盒子。 但在揭示里面有什么之前,我们需要知道如何收集我们要喂它的东西。 从这张图中,我们可以看到我们需要三样东西:标签、学生模型和教师模型的嵌入。 标签,我们已经有了,否则,我们可能会遇到更大的问题。 现在让我们得到另外两个。

2.1 检索教师和学生的输入

在这里,我们将坚持我们的示例并使用带有分类头的 RoBERTa 来说明这部分。 我们需要的是一个函数,给定类似 BERT 模型的输入,即两个张量( input_ids 和  Attention_mask)以及模型本身,将返回该模型的输出 logits

由于我们使用的是 Hugging Face,所以这非常简单,我们唯一需要的知识就是看哪里。

from torch import Tensor

def get_logits(
    model : RobertaPreTrainedModel, 
    input_ids : Tensor,
    attention_mask : Tensor,
) -> Tensor:
    """
    Given a RoBERTa (model) for classification and the couple of (input_ids) and (attention_mask),
    returns the logits corresponding to the prediction.
    """
    return model.classifier(
        model.roberta(input_ids, attention_mask)[0]
    )

我们为学生模型和老师模型都执行这个操作,第一个有梯度,第二个没有梯度。

2.2 损失函数计算

如果损失函数有点不透明,我们建议你返回第一篇文章来阅读损失函数。 但是,如果没有时间这样做,下图应该会有所帮助:

DistilBERT 的损失

我们所说的 Converging consine loss(收敛余弦损失)是用于对齐两个输入向量的常规余弦损失。 有关更多信息,请参阅该系列的第一部分。 这是代码:

import torch
from torch.nn import CrossEntropyLoss, CosineEmbeddingLoss

def distillation_loss(
    teacher_logits : Tensor,
    student_logits : Tensor,
    labels : Tensor,
    temperature : float = 1.0,
) -> Tensor:
    """
    The distillation loss for distilating a BERT-like model.
    The loss takes the (teacher_logits), (student_logits) and (labels) for various losses.
    The (temperature) can be given, otherwise it's set to 1 by default.
    """
    # Temperature and sotfmax
    student_logits, teacher_logits = (student_logits / temperature).softmax(1), (teacher_logits / temperature).softmax(1)
    # Classification loss (problem-specific loss)
    loss = CrossEntropyLoss()(student_logits, labels)
    # CrossEntropy teacher-student loss
    loss = loss + CrossEntropyLoss()(student_logits, teacher_logits)
    # Cosine loss
    loss = loss + CosineEmbeddingLoss()(teacher_logits, student_logits, torch.ones(teacher_logits.size()[0]))
    # Average the loss and return it
    loss = loss / 3
    return loss

3、更优雅的实现

我希望你不会对 Python 是一种面向对象的编程语言感到震惊。 因此,由于所有这些函数都使用几乎相同的对象,因此不让它们成为类的一部分似乎很奇怪。 如果你想实现这一点,我建议使用 Distillator 类来整理代码,就像这个gist 。 我们不会嵌入这个,因为它很长。

当然,缺少一些东西,比如 GPU 支持、整个训练例程等。但是 DistilBERT 的所有关键思想都可以在那里找到。

4、蒸馏结果

那么以这种方式提炼出来的模型最终表现如何呢? 对于DistilBERT,可以阅读原论文。 对于 RoBERTa,Hugging Face 上已经存在类似 DistilBERT 的精简版本,就在这里。 在 GLUE 基准测试 上,我们可以比较这两个模型:

RoBERTa 与 DistilRoBERTa的对比

至于时间和内存成本,该模型的大小大约是 roberta-base 的三分之二,速度是 roberta-base 的两倍。

5、结束语

通过本系列文章,你应该拥有足够的知识来提炼遇到的任何类似 BERT 的模型。 但为什么要停在那里呢? 大自然充满了蒸馏方法,例如 TinyBERT  或 MobileBERT。 如果你认为其中一个更适合你的需求,那么应该阅读这些文章。 谁知道呢,你可能想尝试一种新的蒸馏方法,因为这是一个日益发展的领域。


原文链接:类 BERT 模型蒸馏实战 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/159470.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue-组件通信(二)

​🌈个人主页:前端青山 🔥系列专栏:Vue篇 🔖人终将被年少不可得之物困其一生 依旧青山,本期给大家带来vue篇专栏内容:vue-组件通信(二) 目录 组件通信(二) (1) props / $emit 1. 父组件向子组…

赋能汽车企业数智化转型,鼎捷软件受邀出席“中国工业软件大会”

由中国国际智能产业博览会组委会、工业和信息化部、重庆市人民政府主办的“第三届中国工业软件大会”在重庆盛大召开。工业软件主管部委及政府部门、产业上下游企业代表和业内大咖、科教领域专家学者等800余位嘉宾代表齐聚,为加快制造业数字化转型和高质量发展建言献…

Echarts 实现两两柱图重叠(背景和实际值柱图)

Echarts实现两两重叠柱状图_echarts 重叠柱状图_Web_阿凯的博客-CSDN博客 引用启发的博客 先来效果: option {backgroundColor: #03213D,animation: true, // 控制动画是否开启animationDuration: 1000, // 动画的时长, 它是以毫秒为单位animationDuration: func…

【Proteus仿真】【Arduino单片机】DS1302时钟

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器,使用PCF8574、LCD1602液晶、DS1302等。 主要功能: 系统运行后,LCD1602显示时间日期。 二、软件设计 /* 作者:…

linux课程第一课------命令的简单的介绍

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

23111703[含文档+PPT+源码等]计算机毕业设计javaweb商城项目全套电商购物系统

文章目录 **软件开发环境及开发工具:****项目功能介绍:****论文截图:****实现****代码片段:** 编程技术交流、源码分享、模板分享、网课教程 🐧裙:776871563 软件开发环境及开发工具: 前端使用…

深入了解百度爬虫工作原理

在当今数字化时代,互联网已经成为人们获取信息的主要渠道之一。而搜索引擎作为互联网上最重要的工具之一,扮演着连接用户与海量信息的桥梁角色。然而,我们是否曾经好奇过当我们在搜索引擎中输入关键词并点击搜索按钮后,究竟是如何…

Springboot更新用户头像

人们通常(为徒省事)把一个包含了修改后userName的完整userInfo对象传给后端,做完整更新。但仔细想想,这种做法感觉有点二,而且浪费带宽。 于是patch诞生,只传一个userName到指定资源去,表示该请求是一个局部更新&#…

element el-date-picker报错Prop being mutated:“placement“快速解决方式

报错信息 Avoid mutating a prop directly since the value will be overwritten whenever the parent component re-renders. Instead, use a data or computed property based on the prop’s value. Prop being mutated: “placement” 报错版本 element-ui 2.15.6 和 2.15…

鸿蒙4.0开发笔记之DevEco Studio启动时不直接打开原项目

1、想要在DevEco Studio启动时不直接打开关闭前的那个项目,可以在设置中进行。 有两个位置可以进入“设置”,一个是左上角的File>Settings,二是右上方的设置图标。 2、进入Settings界面以后,选择Appearance&Behavior下面…

Verilog基础:仿真时x信号的产生和x信号对于各运算符的特性

相关阅读 Verilog基础https://blog.csdn.net/weixin_45791458/category_12263729.html?spm1001.2014.3001.5482 信号爆x也许是所有IC人的噩梦,满屏的红色波形常让人头疼不已,但x信号的产生原因却常常只有几种,只要遵循一定的代码规范&#…

windows安装wsl2以及ubuntu

查看自己系统的版本 必须运行 Windows 10 版本 2004 及更高版本(内部版本 19041 及更高版本)或 Windows 11 才能使用以下命令 在设置,系统里面就能看到 开启windows功能 直接winQ搜 开启hyber-V、使用于Linux的Windows子系统、虚拟机平…

无损音频播放软件 Colibri mac中文版特点介绍

Colibri for mac是一款轻量级的音频播放器软件。它具有简洁的界面设计和快速启动速度,能够提供流畅的音频播放体验。Colibri支持多种常见的音频格式,包括MP3、FLAC、ALAC、AAC等。它还提供了一些实用的功能,如音频均衡器、音频增益控制、播放…

场景交互与场景漫游-osgGA库(5)

osgGA库 osgGA库是OSG的一个附加的工具库,它为用户提供各种事件处理及操作处理。通过osgGA库读者可以像控制Windows窗口一样来处理各种事件 osgGA的事件处理器主要由两大部分组成,即事件适配器和动作适配器。osgGA:GUIEventHandler类主要提供了窗口系统的…

Leetcode -2

Leetcode Leetcode -234.回文链表Leetcode -160.相交链表 Leetcode -234.回文链表 题目:给你一个单链表的头节点 head ,请你判断该链表是否为回文链表。如果是,返回 true ;否则,返回 false 。 示例 1: 输…

【计算思维】少儿编程蓝桥杯青少组计算思维题考试真题及解析D

STEMA考试-计算思维-U8级(样题) 21.下面哪个图形与其它图形不同?( ) A. B. C. D. 22.下列哪个选项是由下图旋转得到的?( ) A. B. C. D. 23.下面哪个图形是用4个 拼成的?( &#xf…

【Kingbase FlySync】界面化管控平台:2.配置数据库同步之KES>KES

【Kingbase FlySync】界面化管控平台:3.配置数据库同步之KES->KES 部署KES数据库到KES数据库同步服务1.登录KFS管理平台2.开始配置数据节点信息(1)配置node1数据节点(2)配置node2数据节点 3.KFS拓扑图配置4.开始部署5.启动同步程序并查验是否运行正常 测试同步1.从node1数据…

Android——gradle插件配置方式——dependencies和plugins

引言 我们知道Android studio 需要gradle插件进行构建和编译,随着AGP的升级,引入gradle插件也发生了变化。旧版本通过build.gradle文件中dependencies代码块引入,新版本通过plugins代码块引入 一、旧版本引入方式dependencies 二、新版本引入…

电子学会C/C++编程等级考试2021年06月(一级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:数的输入和输出 输入一个整数和双精度浮点数,先将浮点数保留2位小数输出,然后输出整数。 时间限制:1000 内存限制:65536输入 一行两个数,分别为整数N(不超过整型范围),双精度浮点数F,以一个空格分开。输出 一行两个数,分…

Day33力扣打卡

打卡记录 最大和查询(排序单调栈上二分) 链接 大佬的题解 class Solution:def maximumSumQueries(self, nums1: List[int], nums2: List[int], queries: List[List[int]]) -> List[int]:ans [-1] * len(queries)a sorted(((a, b) for a, b in zi…