第N8周:使用Word2vec实现文本分类

文章目录

  • 一、数据预处理
    • 1.加载数据
    • 2.构建词典
    • 3.生成数据批次和迭代器
  • 二、模型构建
    • 1.搭建模型
    • 2.初始化模型
    • 3.定义训练与评估函数
  • 三、训练模型
    • 1.拆分数据集并运行模型
  • 四、总结

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、数据预处理

1.加载数据

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type=‘cpu’)

import pandas as pd 

# 加载自定义中文数据
train_data = pd.read_csv('train.csv', sep='\t', header=None)
train_data.head()
01
0还有双鸭山到淮阴的汽车票吗13号的Travel-Query
1从这里怎么回家Travel-Query
2随便播放一首专辑阁楼里的佛里的歌Music-Play
3给看一下墓王之王嘛FilmTele-Play
4我想看挑战两把s686打突变团竞的游戏视频Video-Play
# 构造数据集迭代器
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

x = train_data[0].values[:]
# 多类标签的one-shot展开
y = train_data[1].values[:]

2.构建词典

from gensim.models import Word2Vec
import numpy as np

# 训练Word2vec浅层模型
w2v = Word2Vec(vector_size = 100,   # 是指特征向量的维度,默认为100
               min_count = 3)       # 可以对字典做截断,词频少于min_count次数的单词会被丢弃掉,默认值为5

w2v.build_vocab(x)
w2v.train(x,
          total_examples=w2v.corpus_count,
          epochs=20)

(2732785, 3663560)

Word2Vec可以直接训练模型,一步到位。这里分了三步
●第一步构建一个空模型
●第二步使用 build_vocab 方法根据输入的文本数据 x 构建词典。build_vocab 方法会统计输入文本中每个词汇出现的次数,并按照词频从高到低的顺序将词汇加入词典中。
●第三步使用 train 方法对模型进行训练,total_examples 参数指定了训练时使用的文本数量,这里使用的是 w2v.corpus_count 属性,表示输入文本的数量
如果一步到位的话代码为:

w2v = Word2Vec(x, vector_size=100, min_count=3, epochs=20)
# 将文本转化为向量
def average_vec(text):
    vec = np.zeros(100).reshape((1,100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1,100))
        except KeyError:
            continue
    return vec

# 将词向量保存为Ndarray
x_vec = np.concatenate([average_vec(z) for z in x])

# 保存Word2Vec模型及词向量
w2v.save('w2v_model.pkl')
train_iter = coustom_data_iter(x_vec, y)
len(x), len(x_vec)

(12100, 12100)

label_name = list(set(train_data[1].values[:]))
print(label_name)

[‘Travel-Query’, ‘Radio-Listen’, ‘Alarm-Update’, ‘FilmTele-Play’, ‘TVProgram-Play’, ‘HomeAppliance-Control’, ‘Calendar-Query’, ‘Audio-Play’, ‘Video-Play’, ‘Other’, ‘Music-Play’, ‘Weather-Query’]

3.生成数据批次和迭代器

text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)
text_pipeline("你在干嘛")

array([[ 0.44942591, 0.37034334, 0.82435736, 0.57583929, -2.19971114,
-0.26266199, 1.54612615, 0.86057729, 0.94607782, -0.56024504,
-1.2855403 , -3.96268934, 1.00411272, -0.78717487, 0.11495599,
1.7602468 , 2.57005858, -2.04502518, 4.77852516, -1.15009709,
2.75658896, -0.7439712 , -0.50604325, 0.23402849, -0.85734205,
-0.64015828, -1.63281712, -1.22751366, 2.32347407, -2.94733901,
1.86662954, 1.20093471, -0.22566201, -0.02635491, 1.06643996,
0.17282215, -0.57236505, 3.87719914, -2.36707568, 1.28222315,
0.16626818, 0.52857486, 0.2673108 , 1.32945235, -0.51124085,
0.68514908, 0.87900299, -0.9519761 , -2.69660458, 1.78133809,
-0.16500359, -2.11181024, -1.16635181, 1.22090494, -0.76275884,
-0.01114198, 0.42615444, -1.23754779, 0.07603779, -0.04253516,
1.32692097, -1.66303211, 2.16462026, -0.9799156 , -0.9070952 ,
0.87778991, -1.08169729, 0.92559687, 0.64850095, 0.20967194,
0.26563513, 1.03787032, 2.3587795 , -0.7511736 , 0.74099658,
-0.15902402, -2.69873536, 0.13621271, 1.08319706, -0.18128317,
-1.8476568 , -0.67964274, -2.43600948, 2.98213428, -1.72624808,
-0.87052085, 2.28517788, -1.87188464, -0.26412555, -0.37503278,
1.51758769, -1.25159131, 0.87080194, 0.85611653, 0.85986885,
-0.60930844, -0.11496616, 0.66294981, -2.06530389, 0.11790894]])

label_pipeline("Travel-Query")

0

from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list= [], []

    for (_text, _label) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))

        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)

    return text_list.to(device), label_list.to(device)

# 数据加载器,调用示例
dataloader = DataLoader(train_iter,
                        batch_size=8,
                        shuffle=False,
                        collate_fn=collate_batch)
        

二、模型构建

1.搭建模型

from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

2.初始化模型

num_class = len(label_name)
vocab_size = 100000
em_size = 12
model = TextClassificationModel(num_class).to(device)

3.定义训练与评估函数

import time

def train(dataloader):
    model.train() # 切换到训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in enumerate(dataloader):
        predicted_label = model(text)

        optimizer.zero_grad() # grad属性归零
        loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值
        loss.backward() # 反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪
        optimizer.step() # 每一步自东更新

        # 记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches'
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count, train_loss/total_count))

            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval() # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0
      
    with torch.no_grad():
        for idx, (text,label) in enumerate(dataloader):
            predicted_label = model(text)
    
            loss = criterion(predicted_label, label) # 计算loss值
            # 记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count


三、训练模型

1.拆分数据集并运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# 超参数
EPOCHS = 10 
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset,
                                          [int(len(train_dataset)*0.8), int(len(train_dataset)*0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | '
          'valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,
                                           time.time() - epoch_start_time,
                                           val_acc,val_loss,lr))

    print('-' * 69)

| epoch 1 | 50/ 152 batches| train_acc 0.737 train_loss 0.02661
| epoch 1 | 100/ 152 batches| train_acc 0.814 train_loss 0.01831
| epoch 1 | 150/ 152 batches| train_acc 0.829 train_loss 0.01839
---------------------------------------------------------------------
| epoch 1 | time: 0.32s | valid_acc 0.803 valid_loss 0.026 | lr 5.000000
---------------------------------------------------------------------
| epoch 2 | 50/ 152 batches| train_acc 0.831 train_loss 0.01848
| epoch 2 | 100/ 152 batches| train_acc 0.852 train_loss 0.01733
| epoch 2 | 150/ 152 batches| train_acc 0.836 train_loss 0.01894
---------------------------------------------------------------------
| epoch 2 | time: 0.26s | valid_acc 0.783 valid_loss 0.029 | lr 5.000000
---------------------------------------------------------------------
| epoch 3 | 50/ 152 batches| train_acc 0.877 train_loss 0.01177
| epoch 3 | 100/ 152 batches| train_acc 0.904 train_loss 0.00813
| epoch 3 | 150/ 152 batches| train_acc 0.897 train_loss 0.00820
---------------------------------------------------------------------
| epoch 3 | time: 0.26s | valid_acc 0.877 valid_loss 0.010 | lr 0.500000
---------------------------------------------------------------------
| epoch 4 | 50/ 152 batches| train_acc 0.896 train_loss 0.00757
| epoch 4 | 100/ 152 batches| train_acc 0.903 train_loss 0.00654
| epoch 4 | 150/ 152 batches| train_acc 0.899 train_loss 0.00722
---------------------------------------------------------------------
| epoch 4 | time: 0.26s | valid_acc 0.888 valid_loss 0.009 | lr 0.500000
---------------------------------------------------------------------
| epoch 5 | 50/ 152 batches| train_acc 0.903 train_loss 0.00619
| epoch 5 | 100/ 152 batches| train_acc 0.897 train_loss 0.00631
| epoch 5 | 150/ 152 batches| train_acc 0.897 train_loss 0.00678
---------------------------------------------------------------------
| epoch 5 | time: 0.27s | valid_acc 0.880 valid_loss 0.008 | lr 0.500000
---------------------------------------------------------------------
| epoch 6 | 50/ 152 batches| train_acc 0.900 train_loss 0.00611
| epoch 6 | 100/ 152 batches| train_acc 0.904 train_loss 0.00543
| epoch 6 | 150/ 152 batches| train_acc 0.912 train_loss 0.00522
---------------------------------------------------------------------
| epoch 6 | time: 0.26s | valid_acc 0.888 valid_loss 0.008 | lr 0.050000
---------------------------------------------------------------------
| epoch 7 | 50/ 152 batches| train_acc 0.903 train_loss 0.00555
| epoch 7 | 100/ 152 batches| train_acc 0.919 train_loss 0.00477
| epoch 7 | 150/ 152 batches| train_acc 0.902 train_loss 0.00590
---------------------------------------------------------------------
| epoch 7 | time: 0.26s | valid_acc 0.888 valid_loss 0.008 | lr 0.005000
---------------------------------------------------------------------
| epoch 8 | 50/ 152 batches| train_acc 0.909 train_loss 0.00523
| epoch 8 | 100/ 152 batches| train_acc 0.906 train_loss 0.00561
| epoch 8 | 150/ 152 batches| train_acc 0.911 train_loss 0.00533
---------------------------------------------------------------------
| epoch 8 | time: 0.27s | valid_acc 0.888 valid_loss 0.008 | lr 0.000500
---------------------------------------------------------------------
| epoch 9 | 50/ 152 batches| train_acc 0.914 train_loss 0.00485
| epoch 9 | 100/ 152 batches| train_acc 0.902 train_loss 0.00578
| epoch 9 | 150/ 152 batches| train_acc 0.910 train_loss 0.00554
---------------------------------------------------------------------
| epoch 9 | time: 0.26s | valid_acc 0.888 valid_loss 0.008 | lr 0.000050
---------------------------------------------------------------------
| epoch 10 | 50/ 152 batches| train_acc 0.907 train_loss 0.00569
| epoch 10 | 100/ 152 batches| train_acc 0.911 train_loss 0.00496
| epoch 10 | 150/ 152 batches| train_acc 0.908 train_loss 0.00548
---------------------------------------------------------------------
| epoch 10 | time: 0.28s | valid_acc 0.888 valid_loss 0.008 | lr 0.000005
---------------------------------------------------------------------

test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))

模型准确率为:0.8876

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.float32)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()

# ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

model = model.to("cpu")

print("该文本的类别是:%s" %label_name[predict(ex_text_str, text_pipeline)])

torch.Size([1, 100])
该文本的类别是:Travel-Query

四、总结

本周主要学了使用word2vec实现文本分类,其中主要了解了训练word2vec浅层模型,同时也更加深入地学习了梯度裁剪。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/943986.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

面试场景题系列:设计一致性哈希系统

为了实现横向扩展,在服务器之间高效和均匀地分配请求/数据是很重要的。一致性哈希是为了达成这个目标而被广泛使用的技术。首先,我们看一下什么是重新哈希问题。 1 重新哈希的问题 如果你有n个缓存服务器,常见的平衡负载的方法是使用如下哈希…

SqlSugar配置连接达梦数据库集群

安装达梦数据库时,会自动在当前操作系统中创建dm_svc.conf文件,可以在其中配置集群信息,不同操作系统下的文件位置如下图所示:   dm_svc.conf文件内的数据分为全局配置区域、服务配置区域,以参考文献1中的示例说明&…

scss配置全局变量报错[sass] Can‘t find stylesheet to import.

路径没有错误,使用别名即可 后又提示Deprecation Warning: Sass import rules are deprecated and will be removed in Dart Sass 3.0.0. 将import改为use 使用时在$前添加全局变量所在文件,即variable.

UGUI源码分析 --- UI的更新入口

首先所有的UI组件都是添加到画布(Canvas)显示的,所以首先要从Canvas入手,通过搜索脚本函数以及使用Profiler查看UI的函数的执行,定位到了willRenderCanvases函数 打开UI的文件夹, 通过搜索willRenderCanvas…

音视频入门知识(二)、图像篇

⭐二、图像篇 视频基本要素:宽、高、帧率、编码方式、码率、分辨率 ​ 其中码率的计算:码率(kbps)=文件大小(KB)*8/时间(秒),即码率和视频文件大小成正比 YUV和RGB可相互转换 ★YUV(原始数据&am…

电脑配置maven-3.6.1版本

不要使用太高的版本。 apache-maven-3.6.1-bin.zip 下载这个的maven压缩包 使用3.6.1版本。 解压缩放在本地软甲目录下面: 配置系统环境变量 在系统环境下面配置MAVEN_HOME 点击path 新增一条 在cmd中输入 mvn -v 检查maven的版本 配置阿里云镜像和本地的仓库 …

Python基础语法知识——数据类型的查询、数据类型转化

今天第一次学习python,之前学习过C,感觉学习起来还可以,就是刚用的时候有点手残,想的是python代码,结果写出来就是C,本人决定每天抽出时间写点。同时继续更新NX二次开发专栏学习,话不多说,晚上的…

Boost之log日志使用

不讲理论,直接上在程序中可用代码: 一、引入Boost模块 开发环境:Visual Studio 2017 Boost库版本:1.68.0 安装方式:Nuget 安装命令: #只安装下面几个即可 Install-package boost -version 1.68.0 Install…

C语言初阶习题【17】求N的阶乘( 递归和非递归实现)

1.题目 2.分析 非递归 需要用到循环,n个数就是循环n次,每次和之前的乘起来 例如 5的阶乘 就是 5*4 *3 *2 *1 循环1到5 。需要一个变量来接收每次的结果 注意这个地方是乘,所以要从1 开始,sum 也需要是1而不是0 for(i 1&#xf…

云效流水线自动化部署web静态网站

云效流水线部署静态网站 背景新建流水线配置流水线运行流水线总结 背景 配置流水线以前,每次更新导航网站都要登进去宝塔后台,删掉旧的目录和文件,再上传最新的文件,太麻烦啦 网上的博客基本都是分享vue项目,这一篇是…

【开源项目】数字孪生化工厂—开源工程及源码

飞渡科技数字孪生化工厂管理平台,基于自研孪生引擎,将物联网IOT、人工智能、大数据、云计算等技术应用于化工厂,为化工厂提供实时数据分析、工艺优化、设备运维等功能,助力提高生产效率以及提供安全保障。 通过可视化点位标注各厂…

SpringCloud整合skywalking实现链路追踪和日志采集

1.部署skywalking https://blog.csdn.net/qq_40942490/article/details/144701194 2.添加依赖 <!-- 日志采集 --><dependency><groupId>org.apache.skywalking</groupId><artifactId>apm-toolkit-logback-1.x</artifactId><version&g…

Linux下Nvidia显卡GPU开启驱动持久化

GPU开启驱动持久化的原因 GPU 驱动一直处于加载状态&#xff0c; 减少运行程序时驱动加载的延迟。不开启该模式时&#xff0c;在程序每次调用完 GPU 后&#xff0c; GPU 驱动都会被卸载&#xff0c;下次调用时再重新加载&#xff0c; 驱动频繁卸载加载&#xff0c; GPU 频繁被…

图像处理-Ch4-频率域处理

Ch4 频率域处理(Image Enhancement in Frequency Domain) FT &#xff1a;将信号表示成各种频率的正弦信号的线性组合。 频谱&#xff1a; ∣ F ( u , v ) ∣ [ R 2 ( u , v ) I 2 ( u , v ) ] 1 2 |F(u, v)| \left[ R^2(u, v) I^2(u, v) \right]^{\frac{1}{2}} ∣F(u,v)…

虚拟化 | Proxmox VE 8.x 开源的虚拟化平台快速上手指南

[ 知识是人生的灯塔,只有不断学习,才能照亮前行的道路 ] 0x00 简介说明 前言简述 描述:作为一个爱折腾的IT打工佬,时刻以学习各类新技术新知识为目标,这不正好有一台部署了VMware vSphere ESXi 虚拟化环境的服务器,由于正好安装其系统的磁盘有坏道,经常导致使用 ESXi 异…

rocketmq-push模式-消费侧重平衡-类流程图分析

1、观察consumer线程 使用arthas分析 MQClientFactoryScheduledThread 定时任务线程 定时任务线程&#xff0c;包含如下任务&#xff1a; 每2分钟更新nameServer列表 每30秒更新topic的路由信息 每30秒检查broker的存活&#xff0c;发送心跳请求 每5秒持久化消费队列的offset…

使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器实现可迭代式数据集

2023 年 11 月&#xff0c;Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元&#xff08;数据集和数据加载器&#xff09;的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理…

[2003].第2-01节:关系型数据库表及SQL简介

所有博客大纲 后端学习大纲 MySQL学习大纲 1.数据库表介绍&#xff1a; 1.1.表、记录、字段 1.E-R&#xff08;entity-relationship&#xff0c;实体-联系&#xff09;模型中有三个主要概念是&#xff1a; 实体集 、 属性 、 联系集2.一个实体集&#xff08;class&#xff09…

wps透视数据表

1、操作 首先选中你要的行字段表格 -> 插入 -> 透视数据表 -> 拖动行值&#xff08;部门&#xff09;到下方&#xff0c;拖动值&#xff08;包裹数量、运费&#xff09;到下方 2、删除 选中整个透视数据表 -> delete 如图&#xff1a;

Python-流量分析常用工具脚本(Tshark,pyshark,scapy)

免责声明&#xff1a;本文仅作分享~ 目录 wireshark scapy 例&#xff1a;分析DNS流量 检查数据包是否包含特定协议层&#xff08;过滤&#xff09; 获取域名 例&#xff1a;提取 HTTP 请求中的 Host 信息 pyshark 例&#xff1a;解析 HTTP 请求和响应 例&#xff1a;分…