简洁高效的 NLP 入门指南: 200 行实现 Bert 文本分类 (TensorFlow 版)

简洁高效的 NLP 入门指南: 200 行实现 Bert 文本分类 TensorFlow 版

  • 概述
  • NLP 的不同任务
  • Bert 概述
  • MLM 任务 (Masked Language Modeling)
    • Tokenize
    • MLM 的工作原理
    • 为什么使用 MLM
  • NSP 任务 (Next Sentence Prediction)
    • NSP 任务的工作原理
    • NSP 任务栗子
    • NSP 任务的调整和局限性
  • 安装和环境配置
    • PyTorch
    • Transformers
  • Bert 架构
    • Transformer 模型基础
    • Transformer 的两个主要组成部分
    • Transformer Encoder
    • Bert 的 TransformerEncoder 工作流程
  • 200 行实现 Bert 文本分类 (Tensorflow)

概述

在当今信息时代, 自然语言处理 (NLP, Natural Linguistic Processing) 已经称为人工智能领域的一个关键分支. NLP 的目标是使计算机能够理解, 解释和操作人类语言, 从而在各种应用中发挥作用, 如语音识别, 机器翻译, 情感分析等. 随着技术的进步, NLP 已经从简单的规则和统计方法发展到使用复杂的深度学习模型, 今天我们要来介绍的就是 Bert.

Bert 文本分类

NLP 的不同任务

NLP 的不同任务包含:

  • 文本分类 (Text Classification): 根据文本主题, 将文本分为不同的类别, 李儒新闻分类
  • 情感分析 (Sentiment Analysis): 根据文本的情感倾向, 输出一个数, 表示文本的情感强度, 例如 0~5
  • 机器翻译 (Machine Translation): 根据源语言的文本, 生成目标语言的文本, 例如 zh->en
  • 命名实体识别 (Named Entity Recognition): 将文本中的实体 (例如人名, 地名, 组织名等) 进行标注
  • 句法分析 (Parsing): 根据句子的句法结构, 将句子分解为句子成分
  • 词性标注 (Part-of-speech Tagging): 根据词的语法特征, 给词标注一个词性

今年我们主要介绍的是文本分类任务.

Bert 概述

Bert (Bidirectional Encoder Representations from Transformers) 是一种基于 Transformer 架构的的模型. 在 2018 年由 Google 提出. Bert 采用了双向训练方法, 在模型学习给定的词时, 会考虑其上下文.

Bert 的双向训练方法包括下面两个方面:

  • 模型结构: Bert 模型结构采用了双向 Transformer 编码器, 即模型可以从输入两端同时进行编码
  • 预训练任务: Bert 的预训练任务包括 MLM (Masked Language Modeling) 任务和 NSP 任务, 这两个任务都需要 Bert 模型能够从文本的两端进行推理

Bert 架构

MLM 任务 (Masked Language Modeling)

MLM (Masked Language Modeling) 任务: 在 MLM 任务重, 会在输入文本中随机屏蔽一部分单词, 然后要求 Bert 模型预测被 Masked 单词的正确值.

Tokenize

分词 (Tokenization): 将文本按词 (Word) 为单位进行分割, 并转换为数字数据.
- 常见单词, 例如数据中的人名:
- Rachel对应 token id 5586
- Chandler对应 token id 13814
- Phoebe对应 token id 18188
- 上述 token id 对应 bert 的 vocab 中, roberta 的 vocab 表在服务器上, 懒得找了
- 特殊字符:
- [CLS]: token id 101, 表示句子的开始
- [SEP]: token id 102, 表示分隔句子或文本片段
- [PAD]: token id 0, 表示填充 (Padding), 当文本为达到指定长度时, 例如 512, 会用[PAD]进行填充
- [MASK]: token id 0, 表示填充 (Padding), 当文本为达到指定长度时, 例如 512, 会用[PAD]进行填充

上述字符在 Bert & Bert-like 模型中扮演着至关重要的角色, 在不同的任务重, 这些 Token ID 都是固定的, 例如 Bert 为 30522 个.

FYI: 上面的超链接是 jieba 分词的一个简单示例.

MLM 的工作原理

在 MLM 任务重, 输入文本首先被 Tokenize (分词), 词被转换为一个个数字数据, 文本由常见单词和特殊字符组成. 在处理过程中, 模型随机选择文本中的一定比例的 token (栗如: 15%). 并将这些标记替换为一个特定的特殊标记, 如[MASK](token id 0). 模型的任务是啥预测这些 mask token 的原始值.

为什么使用 MLM

MLM 的主要目的是使模型能够更好的理解语言的上下文和语义. 在传统的语言模型 (如 N-gram, 隐马可夫模型 HMM, 循环神经网络 RNN) 训练中模型都是单向的, 即模型只能考虑单词的前面或后面的上下文. 通过 MLM, 模型被迫学习使用一个单词前后的上下文来预测这个单词, 从而获得更全面的语言理解能力.

NSP 任务 (Next Sentence Prediction)

NSP (Next Sentence Prediction) 是 Bert 模型中的一个关键组成部分. NSP 用于改善模型对句子关系的理解, 特别是在理解段落或文档中句子关系方面. 这种能力对许多 NLP 任务至关重要, 例如: 问答系统, 文本摘要, 对话系统等.

NSP 任务的工作原理

在 NSP 任务重, 模型被训练来预测两个句子是否在原始文本中相邻. 这个过程涉及对句子间和语义关系的深入理解. 个栗子: A & B 俩句子, 模型需要判断 B 是否是紧跟在 A 后面的下一句. 在 Training 过冲中, Half time B 确实是 A 的下一句, 另一半时间 B 则是从语料库中随机选取的与 A 无关的句子. NSP 就是基于这些句子判断他们是否是连续的, 强迫模型学习识别句子的连贯性和上下文关系.

NSP 任务栗子

连续:
- 句子 A: “我是小白呀今年才 18 岁”
- 句子 B: “真年轻”
- NSP: 连续, B 是对 A 的回应 (年龄), 表达了作者 “我” 十分年轻

不连续:
- 句子 A: “意大利面要拌”
- 句子 B: “42 号混凝土”
- NSP: 不连续, B 和 A 内容完全无关

NSP 任务的调整和局限性

尽管在 NSP 和 Bert 的初期奔波中被广泛使用, 但是 NSP 也存在一些局限性. NSP 任务有时可能过于简化, 无法完全捕捉复杂文本中的细微关系.

随着 NLP 模型的发展, 一些研究发现去除 NSP 对某些模型的性能影响不大, 例如: Roberta, Xlnet, 和 Deberta 等后续模型都去除了 NSP 任务. 因为这些模型的底层双向结构已经足够强大, 能欧在没有 NSP 的情况下理解句子间的复杂关系.

安装和环境配置

PyTorch

pip install pytorch

Transformers

pip install transformers

Bert 架构

Bert 架构

Transformer 模型基础

Transformer 模型在 2017 年被提出, 是一种基于注意力机制 (Attention) 的架构, 用于处理序列数据. 与之前的序列处理模型 (RNN 和 LSTM) 不同, Transformer 完全依赖于注意力机制来捕获序列的全局依赖关系, 这使得模型在处理长距离依赖时更加有效.

Transformer 的两个主要组成部分

  1. Encoder (编码器): 负责处理输入数据
  2. Decoder (解码器): 负责生成输出数据

Transformer Encoder

Bert 的核心组成部分之一是基于 Transformer 的编码器, 即 TrasnformerEncoder.

class TransformerEncoder(Layer):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # 由多层encoder_layer组成,论文中给出,bert-base是12层,bert-large是24层,一层结构就如上图中蓝色框里的结构
        # num_layers = 12 or 24
        # LayerList称之为容器,使用方法和python里的list类似
        self.layers = LayerList([(encoder_layer if i == 0 else type(encoder_layer)(**encoder_layer._config)) for i in range(num_layers)])
        self.num_layers = num_layers

TransformerEncoder 由多个相同的层堆叠而成, 每层包含两个主要子层:

  1. 多头自注意力机制 (Multi-Head Self-Attention): 这个机制允许模型在处理每个单词时考虑到句子中的所有其他单词, 从而捕获复杂的内部依赖关系. Multi-Head 的设计使得模型能够同时从不同的表示子空间中学习信息
  2. 前馈神经网络 (Feed-Forward Neural Network): 每个注意力层后面都跟着一个简单的前馈神经网络, 这个网络对每个位置的输出进行独立处理

每个子层后面有一个残差链接 (Residual Connection) 和层归一化 (Layer Normalization). 残差连接有助于避免在深层网络中出现的梯度消失 (Vanishing Gradient) 问题, 而层归一化则有助于稳定训练过程.

Bert 的 TransformerEncoder 工作流程

  1. 输入表示: 输入文本首先被转换成词嵌入向量, 然后加上位置编码 (Positional Encoding), 以提供位置信息
  2. 通过多头自注意力 (Multi-Head Self-Attention) 层, 模型学习如何更加其他单词信息调整每个单词的表示
  3. 前馈网络: 每个位置的输出被送入前馈网络, 进一步处理每个单词的表示
  4. 重复多层处理: 过程在多个TransformersEncoder层中重复进行, 每一层都进一步增强了模型对文本的理解

200 行实现 Bert 文本分类 (Tensorflow)

"""
@Module Name: bert.py
@Author: CSDN@我是小白呀
@Date: December 14, 2023

Description:
200 行实现 Bert 文本分类
"""
"""
@Module Name: bert.py
@Author: CSDN@我是小白呀
@Date: December 14, 2023

Description:
200 行实现 Bert 文本分类
"""
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from transformers import BertTokenizer, TFBertModel
import pickle
import logging


# 超参数
EPOCHS = 20  # 迭代次数
BATCH_SIZE = 16  # 单词训练样本数目
learning_rate = 5e-6  # 学习率
INPUT_DIM = 50000
MAX_LENGTH = 512
optimizer = Adam(learning_rate=learning_rate)  # 优化器
loss = tf.keras.losses.MeanSquaredError()  # 损失
bert_tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")  # Bert的分词器
bert_model = TFBertModel.from_pretrained("bert-large-uncased")  # Bert的分词器
prefix = "_raw"
logging.basicConfig(filename='model/bert_large/training_log.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')



class TrainingLoggingCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            logging.info(f"Epoch {epoch + 1}/{EPOCHS}")
            logging.info(f"loss: {logs['loss']} - accuracy: {logs['accuracy']}")
            logging.info(f"val_loss: {logs['val_loss']} - val_accuracy: {logs['val_accuracy']}")
            logging.info(f"lr: {self.model.optimizer.lr.numpy()}")



def get_data():
    """
    读取数据
    :return: 返回分批完的训练集和测试集
    """

    # 读取
    with open('../save/raw/train.pkl', 'rb') as f:
        combined_data = pickle.load(f)
        
    X_train = combined_data['X_train']
    X_valid = combined_data['X_valid']
    y_train = combined_data['y_train']
    y_valid = combined_data['y_valid']



    # 获取input/mask
    train_input = X_train["input_ids"]
    train_mask = X_train["attention_mask"]
    train_input = np.asarray(train_input)
    train_mask = np.asarray(train_mask)

    val_input = X_valid["input_ids"]
    val_mask = X_valid["attention_mask"]
    val_input = np.asarray(val_input)
    val_mask = np.asarray(val_mask)

    return train_input, val_input, train_mask, val_mask, y_train, y_valid

def lr_schedule(epoch, lr):
    """
    学习率递减
    """
    if epoch < 1:
        return lr
    else:
        return lr * 0.9

def main():
    
    # 加载数据
    X_train_input, X_test_input, X_train_mask, X_test_mask, y_train, y_test = get_data()
    print(X_train_input[:5], X_train_input.shape)
    print(X_test_input[:5], X_test_input.shape)
    print(X_train_mask[:5], X_train_mask.shape)
    print(X_test_mask[:5], X_test_mask.shape)
    print(y_train[:5], y_train.shape)
    print(y_test[:5], y_test.shape)


    input_ids = tf.keras.Input(shape=(MAX_LENGTH,), dtype=tf.int32)
    masks = tf.keras.Input(shape=(MAX_LENGTH,), dtype=tf.int32)
    bert = bert_model([input_ids, masks])
    bert = bert[1]

    classifier = Dense(24, activation='softmax')(bert)

    # 模型
    model = Model(inputs=[input_ids, masks], outputs=classifier)
    print(model.summary())

    # 组合
    model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

    # 保存
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        "model/bert_large/bert_large.ckpt", monitor='val_loss',
        verbose=1, save_best_only=True, mode='min',
        save_weights_only=True
    )

    # 学习率调度
    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
    
    # 早停
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',  
        patience=2, 
        verbose=1,
        restore_best_weights=True
    )

    # 训练
    model.fit([X_train_input, X_train_mask], y_train, validation_data=([X_test_input, X_test_mask], y_test),
              epochs=EPOCHS, batch_size=BATCH_SIZE,
              callbacks=[checkpoint, lr_scheduler, TrainingLoggingCallback(), early_stopping])


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/249441.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8重要文件解读

&#x1f368; 本文为[&#x1f517;365天深度学习训练营学习记录博客 &#x1f366; 参考文章&#xff1a;365天深度学习训练营 &#x1f356; 原作者&#xff1a;[K同学啊 | 接辅导、项目定制] &#x1f680; 文章来源&#xff1a;[K同学的学习圈子](https://www.yuque.com/m…

js输入框部分内容不可编辑,其余正常输入,el-input和el-select输入框和多个下拉框联动后的内容不可修改

<tr>//格式// required自定义指令<e-td :required"!read" label><span>地区&#xff1a;</span></e-td><td>//v-if"!read && this.data.nationCode 148"显示逻辑<divclass"table-cell-flex"sty…

【CASS精品教程】cass11提示“请不要在虚拟机中运行此程序”的解决办法

文章目录 一、问题提示二、解决办法一、问题提示 按照正常安装教程安装好南方测绘cass 11之后,打开的时候可能会有以下提示:请不要在虚拟机中运行此程序,如下图所示: 遇到问题,咱们就想办法解决问题,下面将自己尝试的方法及最终解决情况跟大家说一下,供参考。 二、解决…

基于ssm图书商城网站的设计和开发论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本图书商城网站就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信息&am…

【JavaEE】锁的策略

作者主页&#xff1a;paper jie_博客 本文作者&#xff1a;大家好&#xff0c;我是paper jie&#xff0c;感谢你阅读本文&#xff0c;欢迎一建三连哦。 本文于《JavaEE》专栏&#xff0c;本专栏是针对于大学生&#xff0c;编程小白精心打造的。笔者用重金(时间和精力)打造&…

[css] flex wrap 九宫格布局

<div class"box"><ul class"box-inner"><li>九宫格1</li><li>九宫格2</li><li>九宫格3</li><li>九宫格4</li><li>九宫格5</li><li>九宫格6</li><li>九宫格7&l…

【AI工具】GitHub Copilot IDEA安装与使用

GitHub Copilot是一款AI编程助手&#xff0c;它可以帮助开发者编写代码&#xff0c;提供代码建议和自动完成功能。以下是GitHub Copilot在IDEA中的安装和使用步骤&#xff1a; 安装步骤&#xff1a; 打开IDEA&#xff0c;点击File -> Settings -> Plugins。在搜索框中输…

phpstudy是什么?

PHPStudy 是一个集成环境工具&#xff0c;它将 PHP 开发所需的软件&#xff0c;如 Apache&#xff08;Web服务器&#xff09;、MySQL&#xff08;数据库服务器&#xff09;、PHP&#xff08;脚本语言&#xff09;等打包在一起&#xff0c;以便用户能够轻松安装和配置这些软件&a…

【第1期】SpringSecurity基于角色和权限的细粒度接口权限控制

SpringSecurity 细粒度权限控制 一、Role 和 Authority的区别 角色用来表示某一类权限的集合&#xff0c;权限粒度更小&#xff0c;方便细粒度控制 二、创建用户、角色、权限相关表&#xff1a; CREATE TABLE common_user (id bigint(20) NOT NULL COMMENT 主键id,login_na…

详细教程 - 从零开发 鸿蒙harmonyOS应用 第四节 (鸿蒙Stage模型 登录页面 ArkTS版 推荐使用)

在鸿蒙OS中&#xff0c;Ability是应用程序提供的抽象功能&#xff0c;可以理解为一种功能。在应用程序中&#xff0c;一个页面即一种能力&#xff0c;如登录页面&#xff0c;即具有登录功能的能力。以下是对鸿蒙新建项目的登录代码功能的详细解读和工作流程的描述&#xff1a; …

人工智能在红斑狼疮应用主要以下4个方面

人工智能&#xff08;Artificial Intelligence, AI&#xff09;在医学领域的应用已取得了一定的进展。红斑狼疮&#xff08;Systemic Lupus Erythematosus, SLE&#xff09;是一种免疫系统性疾病&#xff0c;对该疾病进行诊断和治疗是一个复杂的过程。人工智能可以发挥作用&…

棒材生产线的7大智能化提升方向 蓝鹏可定制3大类

轧钢智能化控制体系&#xff0c;实行智能化轧钢&#xff0c;提高产品合格率&#xff0c;满足棒材生产线对于产品精度、生产产量、远程集中操控的需求&#xff0c;是钢厂一直致力于实现的目标&#xff0c;目前可从七大方向对棒材产线的智能化方向进行提升。 棒材生产线有以下智…

CRM客户管理系统-超详细介绍

1. CRM概述 CRM&#xff08;Customer Relationship Management&#xff09;客户关系管理&#xff0c;是一种以客户为中心&#xff0c;通过与客户建立持久的、互惠互利的合作关系&#xff0c;从而提高企业整体绩效的管理方法。CRM系统是支持CRM战略的软件工具&#xff0c;用于…

用Pyinstaller打包深度学习算法为独立的可执行程序

前言&#xff1a;随着深度学习算法的流行&#xff0c;在传统工业软件计算领域&#xff0c;传统算法逐渐被深度学习算法给代替&#xff0c;但由于基于python的深度学习算法十分依赖python环境以及例如Pytorch、Scikit-learning、Keras等机器学习库&#xff0c;将深度学习算法运用…

HarmonyOS4.0从零开始的开发教程17给您的应用添加通知

HarmonyOS&#xff08;十五&#xff09;给您的应用添加通知 通知介绍 通知旨在让用户以合适的方式及时获得有用的新消息&#xff0c;帮助用户高效地处理任务。应用可以通过通知接口发送通知消息&#xff0c;用户可以通过通知栏查看通知内容&#xff0c;也可以点击通知来打开应…

linux 查看服务启动时间

文章目录 linux 查看服务启动时间参数解析 linux 查看服务启动时间 [root104 ~]# ps -o lstart -p ps -ef |grep -v grep |grep "zookeeper"|awk {print$2}STARTED Fri Dec 15 16:54:10 2023参数解析 linux 命令中 ps -ef 详解 ps -ef表示查看全格式的进程。 ps …

UE5 PlaceActor

⚠️ 重点 PlaceActors 需在引擎初始化之后 但&#xff0c;单为这一个功能&#xff0c;更改整个模块的启动顺序&#xff0c;也不太划算 更好的办法是&#xff0c;启动顺序保持正常&#xff08;如"LoadingPhase": "Default" &#xff09;&#xff0c;然后…

超燃超欢乐!修仙喜剧动画《师兄啊师兄》第二季稳健开播

12月14日&#xff0c;备受瞩目的《师兄啊师兄》第二季终于稳健开播&#xff01;首播两集连放&#xff0c;同时第一季全13集限免&#xff0c;不仅便于新观众丝滑入坑&#xff0c;老观众也可以二刷重温&#xff0c;可以说是非常良心了&#xff01; 《师兄啊师兄》改编自人气网络小…

LeetCode(63)旋转链表【链表】【中等】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 旋转链表 1.题目 给你一个链表的头节点 head &#xff0c;旋转链表&#xff0c;将链表每个节点向右移动 k 个位置。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], k 2 输出&#xff1a;[4,5,1,2,3]示例 2&…

PE硅芯管抗紫外线和微生物侵害,对水质不会造成任何影响

PE硅芯管是一种优质的管道材料&#xff0c;具有出色的抗紫外线和微生物侵害的能力。这种管道材料采用特殊的生产工艺&#xff0c;添加了硅质材料&#xff0c;从而增强了管道的耐久性。 由于其抗紫外线性能强&#xff0c;PE硅芯管即使在户外长时间暴露于阳光下也不会出现老化、…