使用 BERT 进行文本分类 (02/3)

一、说明

        在使用BERT(1)进行文本分类中,我向您展示了一个BERT如何标记文本的示例。在下面的文章中,让我们更深入地研究是否可以使用 BERT 来预测文本是使用 PyTorch 传达积极还是消极的情绪。首先,我们需要准备数据,以便使用 PyTorch 框架进行分析。

二、什么是 PyTorch

        PyTorch 是用于构建深度学习模型的框架,深度学习模型是一种机器学习,通常用于图像识别和语言处理等应用程序。它由Facebook的人工智能研究小组于2016年开发,由于其灵活性,易用性和动态计算图构建而广受欢迎。

        PyTorch 提供了一个基于 Python 的科学计算包,它使用图形处理单元 (GPU) 的强大功能来加速张量运算的计算。它具有简单直观的API,允许开发人员快速构建和训练深度学习模型。PyTorch 还支持自动微分,使用户能够计算任意函数的梯度。

三、准备我们的数据集

        首先,让我们从Github下载我们的数据。这里有一个关于如何从Github下载CSV文件的小提醒。只需继续并单击以下链接:

github.com

        然后,右键单击“原始”,然后左键单击“将链接文件下载为...”。您将看到“垃圾邮件.csv”并下载它。下载后,将其保存到您的首选文件夹中以供以后使用。

        现在,让我们导入数据。我们看到一条错误消息,告诉我们部分数据未采用 UTF-8 编码。

import pandas as pd
df = pd.read_csv("spam.csv")

ERROR: 
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 606-607: invalid continuation byte

我们可以通过了解数据包含的字符编码并在读取数据时调用该编码来修复此错误。

# Use chardet to know the character encoding 
import chardet
with open("spam.csv", 'rb') as rawdata:
    result = chardet.detect(rawdata.read(100000))
result

Output: 
{'encoding': 'Windows-1252', 'confidence': 0.7270322499829184, 'language': ''}

似乎我们的数据是在“Windows-1252”中编码的。那让我们再读一遍。它奏效了!

df = pd.read_csv("spam.csv", encoding = 'Windows-1252')
df.head()

        如我们所见,我们实际上并不需要“v1”和“v2”以外的列。此外,如果我们将“v1”和“v2”重命名为“类别”和“消息”,则更容易理解。

df = df.loc[:, ['v1', 'v2']]
df = df.rename(columns={'v1': 'Category', 'v2': 'Message'})
df.head()

        现在,我们应该看看我们的数据集,看看每个类别中有多少条消息。

df['Category'].value_counts()

Output: 
ham     4825
spam     747
Name: Category, dtype: int64

四、创建平衡数据集

        事实证明,正常邮件比垃圾邮件多。构建机器学习模型时,如果数据集不平衡,其中一个类中的数据数量明显多于另一个类,则可能会对模型的性能产生各种影响。一些潜在的后果。例如:

-1 有偏差模型:如果数据集不平衡,模型可能会偏向多数类,而对少数类表现不佳。这是因为模型更有可能预测多数类,这将导致少数类的准确性较差。

-2 泛化不良:不平衡的数据集可能导致模型泛化不良。这是因为该模型将在不代表数据真实世界分布的数据集上进行训练,因此它可能无法很好地概括看不见的数据。

-3 评估不准确:如果使用准确性作为指标评估模型,则可能会产生误导性结果。例如,始终预测不平衡数据集中多数类的模型可能具有很高的准确性,但对少数类没有用。

-4 过拟合:由于数据点数量较多,模型可能会过度拟合多数类,从而导致测试数据的性能不佳。

为了解决这些问题,可以使用各种技术来平衡数据集,例如对少数类进行过采样,对多数类进行欠采样,或同时使用两者的组合。在这篇文章中,我将使用欠采样方法。

df_spam = df[df['Category']=='spam']
df_ham = df[df['Category']=='ham']
df_ham_downsampled = df_ham.sample(df_spam.shape[0])
df_balanced = pd.concat([df_ham_downsampled, df_spam])
df_balanced['Category'].value_counts()

Output: 
ham     747
spam    747
Name: Category, dtype: int64

五、标记数据

        当数据表示为数字而不是分类为用于训练和测试的模型时,机器学习算法在准确性和其他性能指标方面表现更好。我们需要用数值对分类值进行标签编码。在这里,我们创建了一个新列“标签”,如果邮件是垃圾邮件,我们将其标记为 1,否则为 0。

df_balanced['Label']=df_balanced['Category'].apply(lambda x: 1 if x=='spam' else 0)
df_balanced = df_balanced.reset_index(drop=True)

display(df_balanced)

由作者创建

六、训练、验证和测试数据集:谁是谁

        要记住的一件事是,当我们使用 train_test_split 库来训练模型时,我们实际上是将数据集拆分为 TRAINING 数据集和 VALIDATION 数据集,而不是 TRAINING 数据集和 TESTING 数据集。下面提醒一下这些数据集的含义。

  1. 训练集:用于构建我们的模型。我们将使用训练集来找到具有反向传播规则的“最佳”权重和偏差。在此阶段,我们通常会创建多个算法,以便在交叉验证阶段比较它们的性能。
  2. 交叉验证集:此数据集用于比较基于训练集创建的预测算法的性能。我们选择性能最佳的算法。
  3. 测试集:这是“未来”数据集。现在我们已经选择了我们喜欢的预测算法,但我们还不知道它将如何在完全看不见的真实世界数据上执行。因此,我们将我们选择的预测算法应用于我们的测试集,以查看它将如何执行,以便我们可以了解我们的算法在野外的性能。

        因此,在测试集中,我们没有数据的标签,而是使用我们的模型来预测标签。我们只能将手头的数据集拆分为训练集和验证集,因为我们还没有“未来”数据。

七、拆分为训练数据集和验证数据集

        现在我们了解了这三种类型的数据的真正含义,我们可以使用scikit-learn的train_test_split来拆分数据。

from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(df_balanced['Message'],df_balanced['Label'], stratify=df_balanced['Label'], test_size=.2)

X_train.head()

Output: 
708                      ;-) ok. I feel like john lennon.
1386    Cashbin.co.uk (Get lots of cash this weekend!)...
1492    REMINDER FROM O2: To get 2.50 pounds free call...
119     Back in brum! Thanks for putting us up and kee...
89                       Sorry, I can't help you on this.
Name: Message, dtype: object

八、总结

        我们已经学会了如何下载和拆分数据。在下一篇文章中,我们将首先对其进行标记,并使用DistilBERT训练分类器。达门·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/77108.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring中循环依赖解决方案

循环依赖 循环依赖是Spring框架中常见的问题之一,当两个或多个类相互引用对方时,就会出现循环依赖的情况。这种情况下,Spring框架无法确定哪个类应该先实例化和初始化,从而导致异常。常见的解决方法有:构造函数注入、s…

【编程二三事】ES究竟是个啥?

在最近的项目中,总是或多或少接触到了搜索的能力。而在这些项目之中,或多或少都离不开一个中间件 - ElasticSearch。 今天忙里偷闲,就来好好了解下这个中间件是用来干什么的。 ES是什么? ​ ES全称ElasticSearch,是个基于Lucen…

MySQL 主从复制遇到 1590 报错

作者通过一个主从复制过程中 1590 的错误,说明了 MySQL 8.0 在创建用户授权过程中的注意事项。 作者:王祥 爱可生 DBA 团队成员,主要负责 MySQL 故障处理和性能优化。对技术执着,为客户负责。 本文来源:原创投稿 爱可生…

python 自动化学习(四) pyppeteer 浏览器操作自动化

背景 之前我在工作中涉及到了很多地方都是重复性的页面点点点工作,又因为安全保密原则不开放接口和数据库,只有一个页面来提供点击进行操作,就想着用前面学的自动化来实现,但发现前面学的模拟操作对浏览器来说并没有那么友好&…

AI项目二:基于mediapipe的虚拟鼠标控制

若该文为原创文章,转载请注明原文出处。 一、项目介绍 由于博主太懒,mediapipe如何实现鼠标控制的原理直接忽略,最初的想法是想控制摄像头识别手指控制鼠标,达到播放电影的效果。基本上效果也是可以的。简单的说是使用mediapipe检…

uniApp引入vant2

uniApp引入vant2 1、cnpm 下载:cnpm i vantlatest-v2 -S2、main.js文件引入 import Vant from ./node_modules/vant/lib/vant;Vue.use(Vant);3.app.vue中引入vant 样式文件 import /node_modules/vant/lib/index.css;

JVM——栈和堆概述,以及有什么区别?

方法栈 方法栈并不是某一个 JVM 的内存空间,而是我们描述方法被调用过程的一个逻辑概念。 在同一个线程内,T1()调用T2(): T1()先开始,T2()后开始;T2()先结束,T1()后结束。 堆和栈概述 从英文单词角度来…

代码随想录算法训练营第三十六天 | 435. 无重叠区间,763.划分字母区间,56. 合并区间

代码随想录算法训练营第三十六天 | 435. 无重叠区间,763.划分字母区间,56. 合并区间 435. 无重叠区间:eyes:题目总结:eyes: 763.划分字母区间:eyes:题目总结:eyes: 56. 合并区间:eyes:题目总结:eyes: 435. 无重叠区间 题目链接 视频讲解 给定一个区间的…

云原生 envoy xDS 动态配置 java控制平面开发 支持restful grpc实现 EDS 动态endpoint配置

envoy xDS 动态配置 java控制平面开发 支持restful grpc 动态endpoint配置 大纲 基础概念Envoy 动态配置API配置方式动静结合的配置方式纯动态配置方式实战 基础概念 Envoy 的强大功能之一是支持动态配置,当使用动态配置时,我们不需要重新启动 Envoy…

【uni-app报错】获取用户收货地址uni.chooseAddress()报错问题

chooseAddress:fail the api need to be declared in …e requiredPrivateInf 原因: 小程序配置 / 全局配置 (qq.com) 解决: 登录小程序后台申请接口 按照流程申请即可 在项目根目录中找到 manifest.json 文件,在左侧导航栏选择源码视图&a…

Springboot整合Mybatis调用Oracle存储过程

1、配置说明 Oracel11g+springboot2.7.14+mybatis3.5.13 目标:springboot整合mybatis访问oracle中的存储过程,存储过程返回游标信息。 mybatis调用oracle中的存储过程方式 2、工程结构 3、具体实现 3.1、在Oracle中创建测试数据库表 具体数据可自行添加 create table s…

SIP网络音频模块-sip网络对讲音频模块(提供POE受电模块接口)

SIP网络音频模块-sip网络对讲音频模块(提供POE受电模块接口) SIP网络音频模块SV-2401V网络对讲音频模块(支持POE) SV-2403V网络对讲音频模块_网络语音对讲模块 网络音频模块 双向对讲 SIP广播系统 SIP网络音频模块嵌入式网络对…

YOLOv8改进后效果

数据集 自建铁路障碍数据集-包含路障,人等少数标签。其中百分之八十作为训练集,百分之二十作为测试集 第一次部署 版本:YOLOv5 训练50epoch后精度可达0.94 mAP可达0.95.此时未包含任何改进操作 第二次部署 版本:YOLOv8改进版本 首…

Mongodb基础操作

一、简介 MongoDB是一个NoSQL型的数据库,基于分布式文档型储存数据库,由C语言编写,它的特点是开源、高性能、高可用、高扩展、易部署。支持 Golang、RUBY、PYTHON、JAVA、C、PHP等多种开发语言。 二、应用场景 MongoDB适用于高并发读写、数据…

创新零售,京东重新答题?

继新一轮组织架构调整后,京东从低价到下沉动作不断。 新成立的创新零售部在京东老将闫小兵的带领下悄然完成了整合。近日,京喜拼拼已改名为京东拼拼,与七鲜、前置仓等业务共同承载起京东线上线下加速融合的梦想。 同时,拼拼的更…

FPGA: RS译码仿真过程

FPGA: RS译码仿真过程 在上一篇中记录了在FPGA中利用RS编码IP核完成信道编码的仿真过程,这篇记录利用译码IP核进行RS解码的仿真过程,带有程序和结果。 1. 开始准备 在进行解码的过程时,同时利用上一篇中的MATLAB仿真程序和编码过程&#x…

微信小程序|自定义弹窗组件

目录 引言小程序的流行和重要性自定义弹出组件作为提升用户体验和界面交互的有效方式什么是自定义弹出组件自定义弹出组件的概念弹出层组件在小程序中的作用和优势为什么需要自定义弹出组件现有的标准弹窗组件的局限性自定义弹出组件在解决这些问题上的优势最佳实践和注意事

日常BUG——Java使用Bigdecimal类型报错

😜作 者:是江迪呀✒️本文关键词:日常BUG、BUG、问题分析☀️每日 一言 :存在错误说明你在进步! 一、问题描述 直接上代码: Test public void test22() throws ParseException {System.out.p…

Linux怎样处理网络请求——彻底理解IO多路复用

常见的网络IO模型 网络 IO 模型分为四种:同步阻塞 IO、同步非阻塞IO、IO 多路复用、异步非阻塞 IO(Async IO, AIO),其中AIO为异步IO,其他都是同步IO 同步阻塞IO 同步阻塞IO:在线程处理过程中,如果涉及到IO操作&…

这场大学生竞赛中,上百支队伍与合合信息用AI共克难题

目录 0 校企联合共克难题1 北京林业大学:文档格式转换2 浙江中医药大学:个性化题库3 中南林业科技大学:交互场景挖掘4 重庆邮电大学:大模型赋能智能文档5 总结 0 校企联合共克难题 近日,中国大学生服务外包创新创业大…