N6 word2vec文本分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

上周学习了训练word2vec模型,这周进行相关实战

1. 导入所需库和设备配置
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib, warnings

warnings.filterwarnings("ignore")  # 忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import pandas as pd
2. 加载数据
train_data = pd.read_csv('./train.csv', sep='\t', header=None)
print(train_data)
3. 数据预处理
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

x = train_data[0].values[:]
y = train_data[1].values[:]

from gensim.models.word2vec import Word2Vec
import numpy as np

w2v = Word2Vec(vector_size=100, min_count=3)
w2v.build_vocab(x)
w2v.train(x, total_examples=w2v.corpus_count, epochs=20)
  • 定义自定义数据迭代器coustom_data_iter
  • 提取文本和标签数据。
  • 使用Word2Vec训练词向量模型,设置词向量维度为100,最小词频为3。
def average_vec(text):
    vec = np.zeros(100).reshape((1, 100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1, 100))
        except KeyError:
            continue
    return vec

x_vec = np.concatenate([average_vec(z) for z in x])
w2v.save('w2v_model.pkl')

train_iter = coustom_data_iter(x_vec, y)
print(len(x), len(x_vec))
label_name = list(set(train_data[1].values[:]))
print(label_name)

text_pipeline = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline("你在干嘛"))
print(label_pipeline("Travel-Query"))
  • 定义函数average_vec,将文本转换为词向量的平均值。
  • 将所有文本转换为词向量并保存Word2Vec模型。
  • 打印文本和向量的数量,以及所有标签的名称。
  • 定义文本和标签的预处理函数text_pipelinelabel_pipeline
4. 数据加载器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list = [], []

    for (_text, _label) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list)

    return text_list.to(device), label_list.to(device)

dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
  • 定义函数collate_batch,将批次中的文本和标签转换为张量。
  • 创建数据加载器dataloader
5. 定义模型
class TextClassificationModel(nn.Module):

    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        return self.fc(text)

num_class = len(label_name)
model = TextClassificationModel(num_class).to(device)
  • 定义文本分类模型TextClassificationModel,包含一个全连接层。
  • 初始化模型,设置输出类别数。
6. 训练和评估函数
import time

def train(dataloader):
    model.train()
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in enumerate(dataloader):
        predicted_label = model(text)

        optimizer.zero_grad()
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:1d} | {:4d}/{:4d} batches | train_acc {:4.3f} train_loss {:4.5f}'.format(
                epoch, idx, len(dataloader), total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, train_loss, total_count = 0, 0, 0

    with torch.no_grad():
        for idx, (text, label) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)

    return total_acc / total_count, train_loss / total_count
  • 定义训练函数train和评估函数evaluate
7. 训练模型
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

EPOCHS = 10
LR = 5
BATCH_SIZE = 64

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None

train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_, split_valid_ = random_split(train_dataset, [int(len(train_dataset) * 0.8), int(len(train_dataset) * 0.2)])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    lr = optimizer.state_dict()['param_groups'][0]['lr']

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:1d} | time: {:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(
        epoch, time.time() - epoch_start_time, val_acc, val_loss, lr))
    print('-' * 69)

test_acc, test_loss = evaluate(valid_dataloader)
print('模型准确率为:{:5.4f}'.format(test_acc))
  • 定义超参数并初始化损失函数、优化器和学习率调度器。
  • 创建数据集并进行训练集和验证集的划分。
  • 训练模型并在每个epoch后进行验证。
8. 预测函数
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.float32)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()

ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
model = model.to("cpu")
print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])
  • 定义预测函数predict,将文本转换为张量并使用模型进行预测。
  • 使用示例文本进行预测并输出结果。

结果

在这里插入图片描述

总结

这周学习了通过word2vec文本分类,包括数据加载、预处理、模型训练、评估和预测。进一步加深了对word2vec的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/790556.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

04OLED简介和调试方法

OLED简介和调试方法 调试方式串口调试显示屏调试其他调试方法总结: OLED简介硬件电路OLED驱动函数 keil调试模式进入方法keil调试界面窗口简单功能说明更加强大的功能 调试方式 电脑想看什么变量可以直接打印到屏幕,但是单片机很多时候由于成本和电路结构…

多协议网关设计架构与实现,支持 RS485/232、CAN、M-Bus、MQTT、TCP 等工业协议接入(附代码示例)

一、项目概述 1.1 背景 随着物联网技术的快速发展,越来越多的设备需要接入网络进行数据交互。然而,不同设备往往采用不同的通信协议,例如工业现场常用的Modbus、CAN、电力载波等,以及物联网领域常用的MQTT、TCP/IP等&#xff0c…

C++类与对象-基础篇

目录 一、什么是类 1.1 语法定义 1.2 访问限定符 1.3 类域 二、类的实例化 2.1 什么是实例化 2.2 类的大小 三、this指针 3.1 引入 3.2 this指针的使用 一、什么是类 1.1 语法定义 class 类名 {}; 说明 类似于C语言中的结构体,括号后分号不能丢类内成员可…

SSM学习5:AOP

简介 面向切面编程,一种编程范式,指导开发者如何组织程序结构。可以在不经打原始设计的基础上为其进行功能增强。 入门案例 案例:在接口执行前输出当前系统时间 开发模式:XML 或者 注解 思路分析: 导入坐标&#xf…

数学建模·模糊评价法

模糊评价法 一种解决评价问题或者得出最佳方案的方法具体定义 三集:因素集,评语集和权重集,通过模拟矩阵的处理得到最合理的评语具体步骤 因素集 因素集的确定不难,难在对分级评价时,对因素集的分级有技巧评语集 …

TIOBE 7月编程排行榜出炉!Python再次出圈!

又到了周三,本周有过半了,大家好呀! 每月的TIOBE编程排行榜都是技术社区关注的焦点,作为编程语言流行度的晴雨表,它反映了行业趋势和技术走向。2024年7月的榜单揭晓了一个重要变化:Python再次登上榜首&…

如何批量更改很多个文件夹里的文件名中包含文件夹名?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

RedisTemplate 中序列化方式辨析

在Spring Data Redis中,RedisTemplate 是操作Redis的核心类,它提供了丰富的API来与Redis进行交互。由于Redis是一个键值存储系统,它存储的是字节序列,因此在使用RedisTemplate时,需要指定键(Key&#xff09…

论文学习_Getafix: learning to fix bugs automatically

1. 引言 研究背景:现代生产代码库极其复杂并且不断更新。静态分析器可以帮助开发人员发现代码中的潜在问题(在本文的其余部分中称为错误),这对于在这些大型代码库中保持高代码质量是必要的。虽然通过静态分析尽早发现错误是有帮助的,但修复这些错误的问题在实践中仍然主要…

“连阳抓妖”,连阳抓主升浪 后市能成妖

指标名 连阳抓妖通达信副图指标 是否收费 免费 格式 源码 注:公式不带有未来函数 ► 图表展示 使用技巧 本指标源自网红博主的精辟讲解,专为捕捉“潜力妖股”设计。其核心逻辑简单而高效,仅需满足四大核心条件,即可锁定前期未大涨个股的…

基于51单片机的五路抢答器Protues仿真设计

目录 一、设计背景 二、实现功能 三、仿真演示 四、源程序(部分) 一、设计背景 近年来随着科技的飞速发展,单片机的应用正在不断的走向深入。本文阐述了基于51单片机的五路抢答器设计。本设计中,51单片机充当了核心控制器的角…

Python实现傅里叶级数可视化工具

Python实现傅里叶级数可视化工具 flyfish 有matlab实现,我没matlab,我有Python,所以我用Python实现。 整个工具的实现代码放在最后,界面使用PyQt5开发 起源 傅里叶级数(Fourier Series)由法国数学家和物理学家让-巴…

[激光原理与应用-106]:南京科耐激光-激光焊接-焊中检测-智能制程监测系统IPM介绍 - 9 - 图解常见的焊接缺陷/缺欠分类

目录 前言: 1、焊接缺陷的类型 2、焊接缺陷的危害 3、结论 一、功能性缺陷 1.1 虚焊:最重要的非视觉检测的缺陷 1.虚焊的定义 2.虚焊的成因 3.虚焊的危害 4.虚焊的检测与解决 二、成型性缺陷 2.1 黑爆缺陷 1、黑爆缺陷的定义与外观 2、黑爆…

深度优先搜索(所有可达路径)

参考题目:所有可达路径 题目描述 给定一个有 n 个节点的有向无环图,节点编号从 1 到 n。请编写一个函数,找出并返回所有从节点 1 到节点 n 的路径。每条路径应以节点编号的列表形式表示。 输入描述 第一行包含两个整数 N,M&…

红日靶场----(三)2.漏洞利用

上期的通过一句话木马实现对目标主机的持久后门 我使用的是蚁剑,蚁剑安装及使用参考: 下载地址: GitHub - AntSwordProject/AntSword-Loader: AntSword 加载器 安装即使用: 1. 快速入门 语雀 通过YXCMS的后台GETSHELL 利用…

C++第四弹 -- 类与对象(中上) (构造函数 析构函数 拷贝构造函数)

目录 前言构造函数1. 概念2. 特征 析构函数1. 概念2. 特征 拷贝构造函数1. 概念2. 特征 总结 前言 让我们一起揭开 C 对象生命周期管理的神秘面纱,掌握构造函数、析构函数和拷贝构造函数的精髓! 博客主页: 酷酷学!!! 期待更多好文, 点击关注~ 构造函…

Linux系统中磁盘管理LVM与挂载

Linux系统中磁盘管理LVM与挂载 本文以属于Linux系统基本概念,如果以查找教程教程,解决问题为主,只需要查看本文后半部分。如需要系统性学习请查看本文前半部分。 本文操作极容易导致主机无法自动重启,请慎重操作。操作前务必要进…

新手教学系列——crontab 使用不当引发的服务器性能问题

起因及症状 最近,我们的一台服务器随着运行时间的增加,逐渐出现了压力过大的问题。具体表现为数据库连接数飙升至 4000+,Redis 频繁超时,系统报错文件打开数过多等。针对这些问题,我们逐一检查了数据库连接池、Redis 连接池以及系统的 ulimit 配置,但都未能找到问题的根…

ROS服务通信自定义srv

服务通信自定义srv 流程:创建ROS功能包按照固定格式创建srv文件编译配置文件编译生成中间文件 流程: srv 文件内的可用数据类型与 msg 文件一致,且定义 srv 实现流程与自定义 msg 实现流程类似,需查阅msg文件的可以浏览ROS话题通信流程自定义数据msg格式…

7月报名 | 海克斯康CAEfatigue疲劳分析培训

您好!感谢您长期以来对优飞迪科技与海克斯康的关注与支持。我们诚邀您参加海克斯康CAEfatigue疲劳分析培训,特邀海克斯康原厂讲师将通过培训帮助您了解CAEfatigue的功能并使用其进行疲劳分析的过程、参数设置以及软件操作方法和技巧,学会使用…