【NLP练习】使用Word2Vec实现文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、数据预处理

1. 任务说明

本次加入Word2Vec使用PyTorch实现中文文本分类,Word2Vec则是其中的一种词嵌入方法,是一种用于生成词向量的浅层神经网络模型。Word2Vec通过学习大量的文本数据,将每个单词表示为一个连续的向量,这些向量可以捕捉单词之间的语义和句法关系。数据示例如下:
在这里插入图片描述

2. 加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
import pandas as pd

#加载自定义中文数据
train_data = pd.read_csv(r'D:\Personal Data\Learning Data\DL Learning Data\train.csv',sep ='\t',header = None)
train_data.head()

输出:
在这里插入图片描述

#构造数据集迭代器
def coustom_data_iter(texts,labels):
 for x,y in zip(texts,labels):
  yield x,y

x = train_data[0].values[:]
y = train_data[1].values[:]

3. 构建词典

from gensim.models.word2vec import Word2Vec
import numpy as np

#训练Word2Vec浅层神经网络模型
w2v = Word2Vec(vector_size=100,   #特征向量的维度,默认为100
              min_count=3)        #对字典做截断,词频少于min_count次数的单词会被丢弃掉,默认值为5
w2v.build_vocab(x)
w2v.train(x,
         total_examples=w2v.corpus_count,
         epochs=20)

输出:

(2732920, 3663560)
#将文本转化为向量
def average_vec(text):
    vec = np.zeros(100).reshape((1,100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1,100))
        except KeyError:
            continue
    return vec

#将词向量保存为Ndarray
x_vec = np.concatenate([average_vec(z) for z in x])

#保存Word2Vec模型及词向量
w2v.save(r'D:\Personal Data\Learning Data\DL Learning Data\w2v.pkl')
train_iter = coustom_data_iter(x_vec, y)
len(x),len(x_vec)

输出:

(12100, 12100)
label_name = list(set(train_data[1].values[:]))
print(label_name)
['Music-Play', 'Travel-Query', 'Weather-Query', 'Audio-Play', 'Radio-Listen', 'Video-Play', 'Calendar-Query', 'HomeAppliance-Control', 'Alarm-Update', 'Other', 'TVProgram-Play', 'FilmTele-Play']

4. 生成数据批次和迭代器


text_pipeline = lambda x : average_vec(x)
label_pipeline = lambda x : label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

输出:

[[ -3.16253691  -1.9659146    3.77608298   1.06067566  -5.1883576
   -8.70868033   3.89949582  -2.18139926   6.70676575  -4.99197783
   16.07808281   9.24493882 -15.24484421  -6.60270358  -6.24634131
   -3.64680131  -2.53697125   2.8301437    7.22867384  -2.13360262
    2.1341381    6.06681348  -4.65962007   1.23247945   4.33183173
    2.15399135  -1.83306327  -2.49018155  -0.22937663   1.57925591
   -3.22308699   3.56521453   5.94520254   3.46486389   3.46772102
   -4.10725167   0.31579057   9.28542571   7.48527321  -2.93014296
    8.39484799 -11.3110949    4.46019076  -0.64214947  -6.3485507
   -5.3710938    1.6277833   -1.44570495   7.21582842   3.29212736
    0.79481401  10.0952674   -0.72304608  -0.46801499   6.08651663
   -0.67166806  10.56184006   1.74745524  -4.52621601   1.8375443
   -5.368839    10.54501078  -2.85536074  -4.55352878 -13.42422374
    3.17138463   7.39386847  -2.24578104 -16.08510212  -5.7369401
   -2.90420356  -4.19321531   3.29097138  -9.36627482   3.67335742
   -0.80693699  -0.53749662  -3.67742246   0.48116201   5.51754848
    0.82724179   4.13207588   0.86254621  13.13354776  -3.11359251
    2.18450189   9.11669949  -4.88159943   2.01295654  11.02899793
   -5.33385142  -7.47531134  -4.02018939  -0.52363324  -1.79980185
    4.00845213  -2.436053     0.16959296  -7.10417359  -0.55219389]]
5
#生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list = [],[]        
    for(_text, _label) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))
        #文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)

    label_list = torch.tensor(label_list,dtype=torch.int64)
    text_list = torch.cat(text_list)
    return text_list.to(device), label_list.to(device)

#数据加载器
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

二、构建模型

1. 搭建模型

#搭建模型
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, num_class):
        super(TextClassificationModel,self).__init__()
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

2. 初始化模型

#初始化模型
#定义实例
num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextClassificationModel(num_class).to(device)

3. 定义训练与评估函数

#定义训练与评估函数
import time

def train(dataloader):
    model.train()          #切换为训练模式
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 50
    start_time = time.time()
    for idx, (text,label) in enumerate(dataloader):
        predicted_label = model(text)
        optimizer.zero_grad()                             #grad属性归零
        loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真
        loss.backward()                                   #反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(),0.1)  #梯度裁剪
        optimizer.step()                                  #每一步自动更新
        
        #记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc,train_loss,total_count = 0,0,0
            staet_time = time.time()

def evaluate(dataloader):
    model.eval()      #切换为测试模式
    total_acc,train_loss,total_count = 0,0,0
    with torch.no_grad():
        for idx,(text,label) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label,label)   #计算loss值
            #记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count

三、训练模型

1. 拆分数据集并运行模型

#拆分数据集并运行模型
from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# 超参数设定
EPOCHS      = 10   #epoch
LR          = 5    #learningRate
BATCH_SIZE  = 64   #batch size for training

#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None

# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset   = to_map_style_dataset(train_iter)
split_train_, split_valid_ = random_split(train_dataset,
                                         [int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])

                                           
train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    #获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
        epoch,
        time.time() - epoch_start_time,
        val_acc,
        val_loss))
    print('-' * 69)

输出:

|epoch1|  50/ 152 batches|train_acc0.724 train_loss0.02592
|epoch1| 100/ 152 batches|train_acc0.820 train_loss0.01937
|epoch1| 150/ 152 batches|train_acc0.832 train_loss0.01843
---------------------------------------------------------------------
| epoch 1 | time:1.11s | valid_acc 0.827 valid_loss 0.019
---------------------------------------------------------------------
|epoch2|  50/ 152 batches|train_acc0.842 train_loss0.01750
|epoch2| 100/ 152 batches|train_acc0.831 train_loss0.01787
|epoch2| 150/ 152 batches|train_acc0.841 train_loss0.01953
---------------------------------------------------------------------
| epoch 2 | time:1.14s | valid_acc 0.780 valid_loss 0.029
---------------------------------------------------------------------
|epoch3|  50/ 152 batches|train_acc0.873 train_loss0.01189
|epoch3| 100/ 152 batches|train_acc0.884 train_loss0.00944
|epoch3| 150/ 152 batches|train_acc0.905 train_loss0.00763
---------------------------------------------------------------------
| epoch 3 | time:1.09s | valid_acc 0.886 valid_loss 0.009
---------------------------------------------------------------------
|epoch4|  50/ 152 batches|train_acc0.891 train_loss0.00794
|epoch4| 100/ 152 batches|train_acc0.894 train_loss0.00711
|epoch4| 150/ 152 batches|train_acc0.905 train_loss0.00646
---------------------------------------------------------------------
| epoch 4 | time:1.09s | valid_acc 0.874 valid_loss 0.009
---------------------------------------------------------------------
|epoch5|  50/ 152 batches|train_acc0.902 train_loss0.00593
|epoch5| 100/ 152 batches|train_acc0.909 train_loss0.00591
|epoch5| 150/ 152 batches|train_acc0.897 train_loss0.00687
---------------------------------------------------------------------
| epoch 5 | time:1.03s | valid_acc 0.890 valid_loss 0.008
---------------------------------------------------------------------
|epoch6|  50/ 152 batches|train_acc0.909 train_loss0.00592
|epoch6| 100/ 152 batches|train_acc0.900 train_loss0.00609
|epoch6| 150/ 152 batches|train_acc0.904 train_loss0.00607
---------------------------------------------------------------------
| epoch 6 | time:1.02s | valid_acc 0.890 valid_loss 0.008
---------------------------------------------------------------------
|epoch7|  50/ 152 batches|train_acc0.908 train_loss0.00559
|epoch7| 100/ 152 batches|train_acc0.906 train_loss0.00604
|epoch7| 150/ 152 batches|train_acc0.902 train_loss0.00623
---------------------------------------------------------------------
| epoch 7 | time:1.00s | valid_acc 0.888 valid_loss 0.008
---------------------------------------------------------------------
|epoch8|  50/ 152 batches|train_acc0.906 train_loss0.00558
|epoch8| 100/ 152 batches|train_acc0.904 train_loss0.00592
|epoch8| 150/ 152 batches|train_acc0.908 train_loss0.00602
---------------------------------------------------------------------
| epoch 8 | time:1.08s | valid_acc 0.888 valid_loss 0.008
---------------------------------------------------------------------
|epoch9|  50/ 152 batches|train_acc0.903 train_loss0.00566
|epoch9| 100/ 152 batches|train_acc0.911 train_loss0.00550
|epoch9| 150/ 152 batches|train_acc0.904 train_loss0.00630
---------------------------------------------------------------------
| epoch 9 | time:1.20s | valid_acc 0.889 valid_loss 0.008
---------------------------------------------------------------------
|epoch10|  50/ 152 batches|train_acc0.910 train_loss0.00564
|epoch10| 100/ 152 batches|train_acc0.912 train_loss0.00550
|epoch10| 150/ 152 batches|train_acc0.897 train_loss0.00633
---------------------------------------------------------------------
| epoch 10 | time:1.09s | valid_acc 0.889 valid_loss 0.008
---------------------------------------------------------------------
test_acc,test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

输出:

模型准确率为:0.8843

2. 测试指定数据

def predict(text,text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text),dtype = torch.float32)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
print("该文本的类别是:%s" %label_name[predict(ex_text_str,text_pipeline)])
torch.Size([1, 100])
该文本的类别是:Travel-Query

四、总结

Word2Vec 通过学习单词的上下文关系,将单词映射到向量空间。这使得语义上相似的单词在向量空间中具有相近的位置。因此,使用 Word2Vec 可以更好地捕获文本中的语义信息,从而提高文本分类的准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/565863.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【多态】底层原理

博主首页&#xff1a; 有趣的中国人 专栏首页&#xff1a; C进阶 本篇文章主要讲解 多态底层原理 的相关内容 1. 多态原理 1.1 虚函数表 先看一下这段代码&#xff0c;计算一下sizeof(Base)是多少&#xff1a; class Base { public:virtual void Func1(){cout << &quo…

力扣--N皇后

题目: 按照国际象棋的规则&#xff0c;皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上&#xff0c;并且使皇后彼此之间不能相互攻击。 给你一个整数 n &#xff0c;返回所有不同的 n 皇后问题 的解决方案。…

Android驱动开发之如何编译和更换内核

编译内核可以使用图形化的界面配置,也可以直接使用脚本。在X86_64模拟器环境下,不用交叉编译,而交叉编译工具很容易出现兼容问题,一般也只能使用芯片厂商提供的工具,而不是GNU提供的工具。 android内核开发流程以及架构变化了很多,详情请看 内核官网 内核版本选择 由…

蓝桥杯第17169题——兽之泪II

问题描述 在蓝桥王国&#xff0c;流传着一个古老的传说&#xff1a;在怪兽谷&#xff0c;有一笔由神圣骑士留下的宝藏。 小蓝是一位年轻而勇敢的冒险家&#xff0c;他决定去寻找宝藏。根据远古卷轴的提示&#xff0c;如果要找到宝藏&#xff0c;那么需要集齐 n 滴兽之泪&#…

NFTScan | 04.15~04.21 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期&#xff1a;2024.04.15~ 2024.04.21 NFT Hot News 01/ 数据&#xff1a;Bitcoin Puppets 市值超越 Pudgy Penguins&#xff0c;现排名第五 4 月 15 日&#xff0c;据 CoinGecko 数据显示&#xff0c…

LeetCode in Python 72. Edit Distance (编辑距离)

编辑距离的基本思想很直观&#xff0c;即不断比较两个单词每个位置的元素&#xff0c;若相同则比较下一个&#xff0c;若不同则需要考虑从插入、删除、替换三种方法中选择一个最优的策略。涉及最优策略笔者最先想到的即是动态规划的思想&#xff0c;将两个单词的位置对应放在矩…

zigbee cc2530的室内/矿井等定位系统RSSI原理

1. 定位节点软件设计流程 2. 硬件设计 cc2530 最小系统 3. 上位机 c# 设计上位机&#xff0c;通过串口连接协调器节点&#xff0c;传输数据到pc上位机&#xff0c;显示节点坐标信息 4. 实物效果 需要4个节点&#xff0c;其中一个协调器&#xff0c;两个路由器作为参考节点&a…

计算机视觉 | 交通信号灯状态的检测和识别

Hi&#xff0c;大家好&#xff0c;我是半亩花海。本项目旨在使用计算机视觉技术检测交通信号灯的状态&#xff0c;主要针对红色和绿色信号灯的识别。通过分析输入图像中的像素颜色信息&#xff0c;利用OpenCV库实现对信号灯状态的检测和识别。 目录 一、项目背景 二、项目功能…

uni-app 的 扩展组件(uni-ui) 与uView UI

uni-app 的 扩展组件&#xff08;uni-ui&#xff09; 与uView UI uni-ui 官方背景&#xff1a;组件集&#xff1a;设计风格&#xff1a;文档与支持&#xff1a;社区与生态&#xff1a; uView UI 第三方框架&#xff1a;组件集&#xff1a;设计风格&#xff1a;文档与支持&#…

Python --- 新手小白自己动手安装Anaconda+Jupyter Notebook全记录(Windows平台)

新手小白自己动手安装AnacondaJupyter Notebook全记录 这两天在家学Pythonmathine learning&#xff0c;在我刚刚入手python的时候&#xff0c;我写了一篇新手的入手文章&#xff0c;是基于Vs code编译器的入手指南&#xff0c;里面包括如何安装python&#xff0c;以及如何在Vs…

四六级英语听力考试音频无线发射系统在安顺学院的成功应用分析

四六级英语听力考试音频无线发射系统在安顺学院的成功应用分析 由北京海特伟业科技任洪卓发布于2024年4月22日 安顺学院为了提高学生的外语听力水平&#xff0c;并确保英语四六级听力考试的稳定可靠进行&#xff0c;决定对传统的英语听力音频传输系统进行改造&#xff0c;以提供…

海康Visionmaster-常见问题排查方法-启动阶段

VM试用版启动时&#xff0c;弹窗报错&#xff1a;加密狗未安装或检测异常&#xff1b;  问题原因&#xff1a;安装VM 的时候未选择软加密&#xff0c;选择了加密狗驱动&#xff0c;此时要使用软授权就出现了此现象。  解决方法&#xff1a; ① 首先确认软加密驱动正确安装…

单片机 VS 嵌入式LInux (学习方法)

linux 嵌入式开发岗位需要掌握Linux的主要原因之一是&#xff0c;许多嵌入式系统正在向更复杂、更功能丰富的方向发展&#xff0c;需要更强大的操作系统支持。而Linux作为开源、稳定且灵活的操作系统&#xff0c;已经成为许多嵌入式系统的首选。以下是为什么嵌入式开发岗位通常…

机器学习-10-神经网络python实现-从零开始

文章目录 总结参考本门课程的目标机器学习定义从零构建神经网络手写数据集MNIST介绍代码读取数据集MNIST神经网络实现测试手写的图片 带有反向查询的神经网络实现 总结 本系列是机器学习课程的系列课程&#xff0c;主要介绍基于python实现神经网络。 参考 BP神经网络及pytho…

数据挖掘实验(Apriori,fpgrowth)

Apriori&#xff1a;这里做了个小优化&#xff0c;比如abcde和adcef自连接出的新项集abcdef&#xff0c;可以用abcde的位置和f的位置取交集&#xff0c;这样第n项集的计算可以用n-1项集的信息和数字本身的位置信息计算出来&#xff0c;只需要保存第n-1项集的位置信息就可以提速…

去哪儿网开源的一个对应用透明,无侵入的Java应用诊断工具

今天 V 哥给大家带来一款开源工具Bistoury&#xff0c;Bistoury 是去哪儿网开源的一个对应用透明&#xff0c;无侵入的java应用诊断工具&#xff0c;用于提升开发人员的诊断效率和能力。 Bistoury 的目标是一站式java应用诊断解决方案&#xff0c;让开发人员无需登录机器或修改…

microk8s拉取pause镜像卡住

前几天嫌服务器上镜像太多占空间&#xff0c;全部删掉了&#xff0c;今天看到 microk8s 更新了 1.30 版本&#xff0c;果断更新&#xff0c;结果集群跑不起来了。 先通过 microk8s.kubectl get pods --all-namespaces 命令看看 pod 状态。 如上图可以看到&#xff0c;所有的业…

物联网通信中NB-IoT、Cat.1、Cat.M该如何选择?

物联网通信中NB-IoT、Cat.1、Cat.M该如何选择? 参考链接:物联网通信中NB-IoT、Cat.1、Cat.M该如何选择?​​ 在我们准备设计用于大规模联网的物联网设备时,选择到适合的LTE IoT标准将是我们遇到的难点。这是我们一开始设计产品方案就需要解决的一个问题,其决定我们设备需…

HarmonyOS ArkUI实战开发-NAPI 加载原理(下)

上一节笔者给大家讲解了 JS 引擎解释执行到 import 语句的加载流程&#xff0c;总结起来就是利用 dlopen() 方法的加载特性向 NativeModuleManager 内部的链接尾部添加一个 NativeModule&#xff0c;没有阅读过上节文章的小伙伴&#xff0c;笔者强烈建议阅读一下&#xff0c;本…

ChatGPT在线网页版(与GPT Plus会员完全一致)

ChatGPT镜像 今天在知乎看到一个问题&#xff1a;“平民不参与内测的话没有账号还有机会使用ChatGPT吗&#xff1f;” 从去年GPT大火到现在&#xff0c;关于GPT的消息铺天盖地&#xff0c;真要有心想要去用&#xff0c;途径很多&#xff0c;别的不说&#xff0c;国内GPT的镜像…