基于BERT对中文邮件内容分类

用BERT做中文邮件内容分类

    • 项目背景与意义
    • 项目思路
    • 数据集介绍
    • 环境配置
    • 数据加载与预处理
    • 自定义数据集
    • 模型训练
      • 加载BERT预训练模型
      • 开始训练
    • 预测效果

项目背景与意义

本文是《用BERT做中文邮件内容分类》系列的第二篇,该系列项目持续更新中。系列的起源是《使用PaddleNLP识别垃圾邮件》项目,旨在解决企业面临的垃圾邮件问题,通过深度学习方法探索多语言垃圾邮件的内容、标题提取与分类识别。

在本篇文章中,我们使用PaddleNLP的BERT预训练模型,根据提取的中文邮件内容判断邮件是否为垃圾邮件。该项目的思路在于基于前一篇项目的中文邮件内容提取,在98.5%的垃圾邮件分类器基线上,通过BERT的finetune进一步提升性能。
在这里插入图片描述

项目思路

在《使用PaddleNLP识别垃圾邮件(一)》项目的基础上,我们使用BERT进行finetune,力求在LSTM的98.5%的基线上进一步提升准确率。同时,文章中详细介绍了BERT模型的原理和PaddleNLP对BERT模型的应用,读者可以参考项目PaddleNLP2.0:BERT模型的应用进行更深入的了解。

本项目参考了陆平老师的项目应用BERT模型做短文本情绪分类(PaddleNLP 2.0),但由于PaddleNLP版本迭代的原因,进行了相应的调整和说明。

数据集介绍

我们使用了TREC 2006 Spam Track Public Corpora,这是一个公开的垃圾邮件语料库,包括英文数据集(trec06p)和中文数据集(trec06c)。在本项目中,我们仅使用了TREC 2006提供的中文数据集进行演示。数据集来源于真实邮件,保留了邮件的原有格式和内容。

除了TREC 2006外,还有TREC 2005和TREC 2007的英文垃圾邮件数据集,但本项目仅使用了TREC 2006提供的中文数据集。数据集文件目录形式如下:

trec06c
│
├── data
│   │   000
│   │   001
│   │   ...
│   └───215
├── delay
│   │   index
└── full
    │   index

邮件内容样本示例:

负责人您好我是深圳金海实业有限公司...
GG非常好的朋友H在计划马上的西藏自助游...

环境配置

本项目基于Paddle 2.0编写,如果你的环境不是本版本,请先参考官网安装Paddle 2.0。以下是环境配置代码:

# 导入相关的模块
import re
import jieba
import os
import random
import paddle
import paddlenlp as ppnlp
from paddlenlp.data import Stack, Pad, Tuple
import paddle.nn.functional as F
import paddle.nn as nn
from visualdl import LogWriter
import numpy as np
from functools import partial

数据加载与预处理

项目中使用了PaddleNLP的BertTokenizer进行数据处理,该tokenizer可以将原始输入文本转化成模型可接受的输入数据格式。以下是数据加载与预处理的代码:

# 解压数据集
!tar xvf data/data89631/trec06c.tgz

# 去掉非中文字符
def clean_str(string):
    string = re.sub(r"[^\u4e00-\u9fff]", " ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip()

# 从指定路径读取邮件文件内容信息
def get_data_in_a_file(original_path, save_path='all_email.txt'):
    email = ''
    f = open(original_path, 'r', encoding='gb2312', errors='ignore')
    for line in f:
        line = line.strip().strip('\n')
        line = clean_str(line)
        email += line
    f.close()
    return email[-200:]

# 读取标签文件信息
f = open('trec06c/full/index', 'r')
for line in f:
    str_list = line.split(" ")
    if str_list[0] == 'spam':
        label = '0'
    elif str_list[0] == 'ham':
        label = '1'
    text = get_data_in_a_file('trec06c/full/' + str(str_list[1].split("\n")[0]))
    with open("all_email.txt", "a+") as f:
        f.write(text + '\t' + label + '\n')

自定义数据集

在项目中,我们需要自定义数据集,并使其数据格式与使用ppnlp.datasets.ChnSentiCorp.get_datasets加载后完全一致。以下是自定义数据集的代码:

class SelfDefinedDataset(paddle.io.Dataset):
    def __init__(self, data):
        super(SelfDefinedDataset, self).__init__()
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

    def get_labels(self):
        return ["0", "1"]

def txt_to_list(file_name):
    res_list = []
    for line in open(file_name):
        res_list.append(line.strip().split('\t'))
    return res_list

trainlst = txt_to_list('train_list.txt')
devlst = txt_to_list('eval_list.txt')
testlst = txt_to_list('test_list.txt')

train_ds, dev_ds, test_ds = SelfDefinedDataset.get_datasets([trainlst, devlst, testlst])

模型训练

加载BERT预训练模型

项目中使用了PaddleNLP提供的BertForSequenceClassification模型进行文本分类的Fine-tune。由于垃圾邮件识别是二分类问题,所以设置num_classes为2。

以下是加载BERT预训练模型的代码:

# 加载预训练模型
model = ppnlp.transformers.BertForSequenceClassification.from_pretrained("bert-base-chinese", num_classes=2)

开始训练

为了监控训练过程,引入了VisualDL记录训练log信息。以下是开始训练的代码:

# 设置训练超参数
learning_rate = 1e-5
epochs = 10
warmup_proption = 0.1
weight_decay = 0.01

num_training_steps = len(train_loader) * epochs
num_warmup_steps = int(warmup_proption * num_training_steps)

def get_lr_factor(current_step):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    else:
        return max(0.0,
                    float(num_training_steps - current_step) /
                    float(max(1, num_training_steps - num_warmup_steps)))

# 学习率调度器
lr_scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate, lr_lambda=lambda current_step: get_lr_factor(current_step))

# 优化器
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])

# 损失函数
criterion = paddle.nn.loss.CrossEntropyLoss()

# 评估函数
metric = paddle.metric.Accuracy()

# 训练过程
global_step = 0
with LogWriter(logdir="./log") as writer:
    for epoch in range(1, epochs + 1):    
        for step, batch in enumerate(train_loader, start=1):
            input_ids, segment_ids, labels = batch
            logits = model(input_ids, segment_ids)
            loss = criterion(logits, labels)
            probs = F.softmax(logits, axis=1)
            correct = metric.compute(probs, labels)
            metric.update(correct)
            acc = metric.accumulate()

            global_step += 1
            if global_step % 50 == 0:
                print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
                writer.add_scalar(tag="train/loss", step=global_step, value=loss)
                writer.add_scalar(tag="train/acc", step=global_step, value=acc)
            
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
        
        eval_loss, eval_acc = evaluate(model, criterion, metric, dev_loader)
        writer.add_scalar(tag="eval/loss", step=epoch, value=eval_loss)
        writer.add_scalar(tag="eval/acc", step=epoch, value=eval_acc)

可以看到,在第2个epoch后验证集准确率已经达到99.4%以上,在第3个epoch就能达到99.6%以上。

预测效果

完成模型训练后,我们可以使用训练好的模型对测试集进行预测。以下是预测效果的代码:

data = ['您好我公司有多余的发票可以向外代开,国税,地税,运输,广告,海关缴款书如果贵公司,厂,有需要请来电洽谈,咨询联系电话,罗先生谢谢顺祝商祺']
label_map = {0: '垃圾邮件', 1: '正常邮件'}

predictions = predict(model, data, tokenizer, label_map, batch_size=32)
for idx, text in enumerate(data):
    print('预测内容: {} \n邮件标签: {}'.format(text, predictions[idx]))

预测效果良好,一个验证集准确率高达99.6%以上、基于BERT的中文邮件内容分类顺利完成!

以上是本文的全部内容,希望对读者理解如何使用BERT进行中文邮件内容分类有所帮助。欢迎交流指导!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/337854.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

采集B站up主视频信息

一、网页信息&#xff08;示例网址&#xff1a;https://space.bilibili.com/3493110839511225/video&#xff09; 二、查看响应数据 三、查看数据包内容 四、相关代码&#xff08;代码内容未进行翻页爬取&#xff09; # Time: 2024/1/19 16:42 # Author: 马龙强 # File: 采集B…

【Linux】第三十二站:命名管道

文章目录 一、命名管道介绍二、编码1.mkfifo2.unlink3.一个简单的例子4.修改 一、命名管道介绍 管道应用的一个限制就是只能在具有共同祖先&#xff08;具有亲缘关系&#xff09;的进程间通信。 如果我们想在不相关的进程之间交换数据&#xff0c;可以使用FIFO文件来做这项工作…

<软考高项备考>《论文专题 - 78 风险管理(10)》

10 论文-历年真题解析 10.1 2005年上半年真题 请围绕“项目的风险管理”论题&#xff0c;分别从以下三个方面进行论述&#xff1a; 1&#xff0e;概要叙述你参与管理过的信息系统项目&#xff08;项目的背景、发起单位、目的、项目周期、交付的产品等&#xff09;&#xff0c…

【排序算法】五、冒泡排序(C/C++)

「前言」文章内容是排序算法之冒泡排序的讲解。&#xff08;所有文章已经分类好&#xff0c;放心食用&#xff09; 「归属专栏」排序算法 「主页链接」个人主页 「笔者」枫叶先生(fy) 目录 冒泡排序1.1 原理1.2 代码实现&#xff08;C/C&#xff09;1.3 特性总结 冒泡排序 1.1…

每日一题——1295.统计位数为偶数的数字

方法一 个人方法&#xff1a; 想知道整数型数字有多少位&#xff0c;可以直接把数字转字符&#xff0c;看字符的长度就是数字的位数 var findNumbers function(nums) {let count0for(let num of nums){let strnumif(str.length%20) count}return count }; 消耗时间和内存情况…

uni-app使用HBuilderX打包Web项目

非常简单&#xff0c;就是容易忘记 一、找到manifest.json配置Web配置 二、源码视图配置 "h5" : {"template" : "","domain" : "xxx.xx.xx.xxx","publicPath" : "./","devServer" : {&quo…

【Java程序员面试专栏 专业技能篇】MySQL核心面试指引(一):基础知识考察

关于MySQL部分的核心知识进行一网打尽,包括三部分:基础知识考察、核心机制策略、性能优化策略,通过一篇文章串联面试重点,并且帮助加强日常基础知识的理解,全局思维导图如下所示 本篇Blog为第一部分:基础知识考察,子节点表示追问或同级提问 基本概念 包括一些核心问…

什么是葡萄酒“质量三级标准”?

在葡萄酒的世界里有一个笼统的级别分为&#xff1a;入门、精品和顶级。那么&#xff0c;对应这三个级别的标准都是什么呢&#xff1f; 入门级别的标准&#xff1a;入门级别的酒首先喝起来新鲜且顺口。新鲜很容易理解&#xff0c;就是没有腐熟水果的味道&#xff0c;也就是“罐…

8.3最大自序和(LC53-M)

算法&#xff1a; 如果 -2 1 在一起&#xff0c;计算起点的时候&#xff0c;一定是从 1 开始计算&#xff0c;因为负数只会拉低总和&#xff0c;这就是贪心贪的地方&#xff01; &#xff08;-21&#xff0c;起点为负数&#xff0c;加上后面的数&#xff0c;只会让和变小&…

《WebKit 技术内幕》之六(3): CSS解释器和样式布局

3 WebKit布局 3.1 基础 当WebKit创建RenderObject对象之后&#xff0c;每个对象是不知道自己的位置、大小等信息的&#xff0c;WebKit根据框模型来计算它们的位置、大小等信息的过程称为布局计算&#xff08;或者称为排版&#xff09;。 图描述了这一过程中涉及的主要WebKit…

浅谈 ret2text

文章目录 ret2text无需传参重构传参函数调用约定x86x64 ret2text ret2text就是执行程序中已有的代码&#xff0c;例如程序中写有system等系统的调用函数 无需传参 如果程序的后门函数参数已经满足 getshell 的需求&#xff0c;那么就可以直接溢出覆盖 ret 地址不用考虑传参问…

SystemC学习笔记(三) - 查看模块的波形

简述 波形在Simulation/Emulation中地位十分重要&#xff0c;尤其是在研发初期&#xff0c;只能通过波形来查看软件hang住的位置。 对于TLM来说&#xff0c;查看波形一般是指查看pvbus上的transaction&#xff0c;而对于SystemC本身来说&#xff0c;查看波形就是使用Gtkwave或…

WeChall

WeChall-Scored Challenges 一、Get Sourced二、Stegano I三、Crypto - Caesar I四、WWW-Robots五、 ASCII六、URL七、Christmas Hippety八、No DNS 网站链接&#xff1a;https://www.wechall.net/challs/Training/by/chall_score/ASC/page-1 一、Get Sourced 题目链接&#…

HTML+JavaScript-01

说明 之前有一篇JavaWeb-JavaScript中只是简单介绍了一点JavaScript的内容&#xff0c;这篇笔记算是续写的&#xff0c;但是从01开始编号。 引入js文件 html、css、js俗称前端三剑客&#xff0c;一般都是分开写&#xff0c;先写框架、再写css、最后写js。因此在工程量大的情…

HarmonyOS—配置开发环境

应用/服务支持API Version 4至9&#xff0c;首次使用DevEco Studio&#xff0c;工具的配置向导会引导您下载SDK及工具链。配置向导默认下载 API Version 9的SDK及工具链&#xff0c;如需下载API Version 4至8&#xff0c;可在工程配置完成后&#xff0c;进入HarmonyOS SDK界面手…

福建专业建筑清水模板 — 安全可靠的选择

在福建地区的建筑工程中&#xff0c;选择高质量的清水模板对于确保结构的美观和工程的安全至关重要。我们能强优品木业提供的专业建筑清水模板&#xff0c;以其卓越的质量和安全性&#xff0c;成为施工团队的首选。 产品特性 优质材料制作&#xff1a;选用高品质木材制作&…

机械设计-哈工大课程学习-螺纹连接

圆柱螺纹主要几何参数螺纹参数 ①外径&#xff08;大径&#xff09;&#xff0c;与外螺纹牙顶或内螺纹牙底相重合的假想圆柱体直径。螺纹的公称直径即大径。 ②内径&#xff08;小径&#xff09;&#xff0c;与外螺纹牙底或内螺纹牙顶相重合的假想圆柱体直径。 ③中径&#xff…

React 初次接触

背景 还是为了完善高大上的在线文档系统&#xff0c;虽然比着葫芦画瓢的修改了一些所谓的代码&#xff0c;慢慢的才发现&#xff0c;原来这就是传说中的React&#xff0c;所以有比较又要囫囵吞枣一下React。 基本原理 参照《React技术揭秘》 网上有电子版 &#xff0c;应该是…

selenium-java中切换iframe

1、当iframe中有固定的name或者id时可以通过name和id进行切换,代码如下 driver.switchTo().frame("name"); 2、当iframe中没有固定的name或者id时可以通过iframe角标进行切换&#xff0c;在浏览器通过ctrlf快捷键&#xff0c;搜索标签框输入//iframe;来查看当前ifr…

11、Kafka ------ Kafka 核心API 及 生产者API 讲解

目录 Kafka核心API 及 生产者API讲解★ Kafka的核心APIKafka包含如下5类核心API&#xff1a; ★ 生产者APIKafka 的API 文档 ★ 使用生产者API发送消息 Kafka核心API 及 生产者API讲解 官方文档 ★ Kafka的核心API Kafka包含如下5类核心API&#xff1a; Producer API&#x…