第N4周:使用Word2vec实现文本分类

目录

  • 二、数据预处理
    • 1.加载数据
    • 2.构建词典
    • 3.生成数据批次和迭代器
  • 二、模型构建
    • 1.搭建模型
    • 2.初始化模型
    • 3.定义训练与评估函数
  • 三、训练模型
    • 1.拆分数据集并运行模型
    • 2.测试指定数据

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制# 一、课题背景和开发环境
    📌第N4周:使用Word2vec实现文本分类📌

Python 3.8.12
gensim4.3.1
numpy
1.21.5 -> 1.24.3
portalocker2.7.0
pytorch
1.8.1+cu111
torchtext==0.9.1
📌本周任务:📌

结合Word2Vec文本内容(第1列)预测文本标签(第2列)
尝试根据第2周的内容独立实现,尽可能的不看本文的代码
进一步了解并学习Word2Vec
任务说明:
本次将加入Word2vec使用PyTorch实现中文文本分类,Word2Vec则是其中的一种词嵌入方法,是一种用于生成词向量的浅层神经网络模型,由Tomas Mikolov及其团队于2013年提出。**Word2Vec通过学习大量文本数据,将每个单词表示为一个连续的向量,这些向量可以捕捉单词之间的语义和句法关系。**更详细的内容可见训练营内的NLP基础知识,数据示例如下:
在这里插入图片描述

二、数据预处理

1.加载数据

warnings.filterwarnings("ignore")  #忽略警告信息
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device =', device)


print('STEP.1', '-' * 19)
''' 加载自定义中文数据 '''
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
print(train_data.head())
# 构造数据集迭代器
def coustom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y
x = train_data[0].values[:]
# 多类标签的one-hot展开
y = train_data[1].values[:]

2.构建词典

print('STEP.2', '-' * 19)
''' 构建词典 '''
# 训练Word2Vec浅层神经网络模型
w2v = Word2Vec(vector_size=100,  #是指特征向量的维度,默认为100。
               min_count=3)      #可以对字典做截断. 词频少于min_count次数的单词会被丢弃掉, 默认值为5。
w2v.build_vocab(x)
w2v.train(x,
          total_examples=w2v.corpus_count,
          epochs=20)

# 将文本转化为向量
def average_vec(text):
    vec = np.zeros((1,100))
    for word in text:
        try:
            vec += w2v.wv[word].reshape((1,100))
        except KeyError:
            continue
    return vec

# 将词向量保存为ndarray
x_vec = np.concatenate([average_vec(z) for z in x])
print('len(x) =', len(x), 'len(x_vec) =', len(x_vec))
# 保存Word2Vec模型及词向量
w2v.save('output/w2v_model.pkl')
train_iter = coustom_data_iter(x_vec, y)

3.生成数据批次和迭代器

print('STEP.3', '-' * 19)
''' 准备数据处理管道 '''
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline  = lambda x: average_vec(x)
label_pipeline = lambda x: label_name.index(x)
print('你在干嘛', text_pipeline('你在干嘛'))
print('Travel-Query', label_pipeline('Travel-Query'))
''' 生成数据批次和迭代器 '''
def collate_batch(batch):
    label_list, text_list = [], []
    #offsets = [0]
    for (_text, _label) in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.float32)
        text_list.append(processed_text)
        #offsets.append(processed_text.size(0))
    
    label_list = torch.tensor(label_list, dtype=torch.int64)  # torch.Size([64])
    text_list = torch.cat(text_list)  # 若干tensor组成的列表变成一个tensor
    #offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)  # torch.Size([64])
    return text_list.to(device), label_list.to(device)  # , offsets.to(device)
# 数据加载器
#dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)


二、模型构建

1.搭建模型

print('STEP.4', '-' * 19)
''' 搭建文本分类模型 '''
class TextClassificationModel(nn.Module):
    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        self.fc = nn.Linear(100, num_class)
    
    def forward(self, text):
        output = self.fc(text)
        return output

2.初始化模型

print('STEP.5 Initialize', '-' * 19)
''' 初始化实例 '''
num_class  = len(label_name)
vocab_size = 100000  # 词典大小
emsize     = 12      # 嵌入的维度
model      = TextClassificationModel(num_class).to(device)

3.定义训练与评估函数

''' 训练函数 '''
def train(dataloader):
    model.train()  # 训练模式
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()
    
    for idx, (text, label) in enumerate(dataloader):
        optimizer.zero_grad()  # grad属性归零
        predited_label = model(text)
        loss = criterion(predited_label, label)
        loss.backward()  # 反向传播
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # 梯度裁剪
        optimizer.step()  # 每一步自动更新
        # 记录acc与loss
        total_acc += (predited_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, train_acc {:8.3f} train_loss {:8.3f}'.format(epoch, idx, len(dataloader), total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()


''' 评估函数 '''
def evaluate(dataloader):
    model.eval()  # 切换为测试模式
    total_acc, train_loss, total_count = 0, 0, 0
    
    with torch.no_grad():
        for idx, (text, label) in enumerate(dataloader):
            predited_label = model(text)
            loss = criterion(predited_label, label)  # 计算loss值
            # 记录测试数据
            total_acc += (predited_label.argmax(1) == label).sum().item()
            train_loss  += loss.item()
            total_count += label.size(0)
    return total_acc/total_count, train_loss/total_count

三、训练模型

1.拆分数据集并运行模型

''' 开始训练 '''
if __name__ == '__main__':
    # 超参数(Hyperparameters)
    EPOCHS     = 10  # epoch
    LR         = 5   # learning rate
    BATCH_SIZE = 64  # batch size for training
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    # 构建数据集
    train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
    train_dataset = list(train_iter)
    # 划分数据集
    num_train = int(len(train_dataset) * 0.8)
    split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    # 加载数据集
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle表示随机打乱
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val, loss_val = evaluate(valid_dataloader)
        # 获取当前的学习率
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('-' * 59)
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid_acc {:8.3f} valid_loss {:8.3f} | lr {:8.6f}'.format(epoch, time.time()-epoch_start_time, accu_val, loss_val, lr))
        print('-' * 59)
    
    torch.save(model.state_dict(), 'output\\model_TextClassification.pth')
   
    print('Checking the results of test dataset.')
    accu_test, loss_test = evaluate(valid_dataloader)
    print('test accuracy {:8.3f}, test loss {:8.3f}'.format(accu_test, loss_test))

2.测试指定数据

''' 预测函数 '''
def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text), dtype=torch.float32)
        print(text.shape)
        output = model(text)
        return output.argmax(1).item()


''' 以下是预测 '''
if __name__=='__main__':
    model.load_state_dict(torch.load('output\\model_TextClassification.pth'))
    #label_name = ['Alarm-Update', 'Other', 'Audio-Play', 'Calendar-Query', 'HomeAppliance-Control', 'Radio-Listen', 'Travel-Query', 'Video-Play', 'TVProgram-Play', 'FilmTele-Play', 'Weather-Query', 'Music-Play']
    
    ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
    #ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"
    model = model.to("cpu")
    
    print("该文本的类别是:%s" % label_name[predict(ex_text_str, text_pipeline)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/36340.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot 系列2 -- 配置文件

目录 1. 配置文件的作用 2. 配置文件的格式 3. properties 配置文件说明 3.1 properties 基本语法 3.2 读取配置文件 3.3 properties 缺点 4.yml 配置文件说明 4.1 yml 基本语法 4.2 yml 使用进阶 4.2.1 yml 配置不同数据类型及 null 4.2.2 yml 配置读取 4.2.3 注意…

DPWWN1靶场详解

DPWWN1靶场详解 首先还是nmap -sP 192.168.102.0/24扫描到ip地址,然后对这个ip进行一个单独的扫描,发现这个靶场有一个mysql数据库,猜测可能会用到sql注入,但是没用到。 ip登陆到网页发现并没有什么可利用的 唯一的切入点也就数…

Java 动态规划 Leetcode 62. 不同路径

代码展示&#xff1a; class Solution {public int uniquePaths(int m, int n) {//定义dp数组//二维数组多增加一行一列&#xff0c;方便对数组进行初始化int[][]dpnew int[m1][n1];//初始化dp[0][1]1;//填充数组for(int i1;i<m;i){for(int j1;j<n;j){dp[i][j]dp[i-1][j…

基于springboot+Redis的前后端分离项目(七)-【黑马点评】

&#x1f381;&#x1f381;资源文件分享 链接&#xff1a;https://pan.baidu.com/s/1189u6u4icQYHg_9_7ovWmA?pwdeh11 提取码&#xff1a;eh11 发布笔记&#xff0c;点赞&#xff0c;点赞排行 达人探店1、达人探店-发布探店笔记2、 达人探店-查看探店笔记3、 达人探店-点赞功…

《网络安全标准实践指南》(72页)

导读 摘要&#xff1a;为指导网络数据安全风险评估工作&#xff0c;发现数据安全隐患&#xff0c;防范数据安全风险&#xff0c;依据《中华人民共和国网络安全法》《中华人民共和国数据安全法》《中华人民共和国个人信息保护法》等法律法规&#xff0c;参照数据安全相关国家标…

STM32寄存器点亮LED灯

一&#xff1a; 如何寄存器点灯 1&#xff1a;看单片机的原理图 找到LED灯 这个灯是 PB5引脚 看原理图可以看出 让GPIOB5输出低电平 就能点亮那么我们得让打开控制GPIOB5的时钟让GPIOB5 输出模式让GPIOB5低电平 二&#xff1a;看中文参考手册配置寄存器 2.1&#xff1a;打开管…

【Windows】Redis集群部署

集群是如何进行工作的 Redis采用哈希槽来处理数据与节点之间的映射关系&#xff0c;一个集群共有16384 个哈希槽&#xff0c;每个key通过 CRC16算法计算出一个16bit的值&#xff0c;再对16384取模&#xff0c;得到对应的哈希槽&#xff0c;集群通过维护哈希槽与节点的关系来得…

redis与分布式

主从复制 概念 主从复制&#xff0c;是指将一台Redis服务器的数据&#xff0c;复制到其他的Redis服务器。前者称为主节点(Master)&#xff0c;后者称为从节点(Slave)&#xff0c;数据的复制是单向的&#xff0c;只能由主节点到从节点。Master以写为主&#xff0c;Slave 以读为…

MySQL----MHA高可用

文章目录 一、MHA理论1.1什么是 MHA1.2MHA 的组成1.3MHA 的特点 二、MHA的一主两从部署实验设计故障修复步骤&#xff1a; 一、MHA理论 1.1什么是 MHA MHA&#xff08;MasterHigh Availability&#xff09;是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。 MHA 的出…

【Django】Django框架使用指南

Django使用指南 作者简介&#xff1a;嗨~博主目前是长安大学软件工程专硕在读&#x1f4d8;&#xff0c;喜欢钻研一些自己感兴趣的计算机技术&#xff0c;求关注&#x1f609;&#xff01; 框架简介&#xff1a;Django是一个基于Python语言的开源Web应用框架&#xff0c;采用 M…

基于STM32FFT(快速傅里叶变换)音频频谱显示功能实现

+ v hezkz17进数字音频系统研究开发交流答疑 一实验效果 二 设计过程 要用C语言实现STM32频谱显示功能,可以按照以下步骤进行操作: 1 确保已经安装好了适当的开发环境和工具链,例如Keil MDK或者GCC工具链。 2 创建一个新的STM32项目,并选择适合的MCU型号。 3 配置G…

【数据挖掘】时间序列教程【九】

第5章 状态空间模型和卡尔曼滤波 状态空间模型通常试图描述具有两个特征的现象 有一个底层系统具有时变的动态关系,因此系统在时间上的“状态”t 与系统在时间的状态t−1有关 .如果我们知道系统在时间上的状态t−1 ,那么我们就有了我们需要知道的一切,以便对当时的状态进行推…

Android Zygote 启动流程

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: Android系统包含netd、servicemanager、surfaceflinger、zygote、media、installd、bootanimation 等基本服务&#xff0c;具体作用请看下图。 Android…

更开放、更高性能、更具规模,闪马智能布局AGI时代

7月6日&#xff0c;2023世界人工智能大会&#xff08;WAIC 2023&#xff09;在上海盛大开幕。本届大会以“智联世界 生成未来”为主题&#xff0c;聚焦通用人工智能发展&#xff0c;共话产业新未来。 8日上午&#xff0c;由上海闪马智能科技有限公司&#xff08;下称“闪马智能…

el-form实现其中一个填写即可的校验

<el-formref"form":model"formData":rules"formRules"label-width"130px"><el-row :gutter"24"><el-col :span"12"><el-form-item label"司机姓名 :" prop"driverName"…

西电_矩阵论_学习笔记

文章目录 【 第一章 线性空间 】【 第二章 范数 】【 第三章 矩阵函数 】【 第四章 矩阵分解 】【 第五章 矩阵特征值估计 】【 第六章 广义逆 】【 考试重点内容总结 】 这是博主2023春季西电所学矩阵论的思维导图&#xff08;软件是幕布&#xff09;&#xff0c;供大家参考&a…

GaussDB OLTP云数据库配套工具DRS

目录 一、前言 二、DRS定义与使用场景 1、DRS定义 2、DRS场景示意图 三、DRS核心功能 1、实时迁移管理 2、实时同步管理 3、备份迁移管理 4、数据订阅管理 5、实时灾备管理 四、小结 一、前言 华为GaussDB云数据库提供了配套的生态工具数据复制服务DRS。 DRS围绕云…

漏洞深度分析 | CVE-2023-36053-Django 表达式拒绝服务

​ 项目介绍 Django 是一个高级 Python Web 框架&#xff0c;鼓励快速开发和简洁、务实的设计。它由经验丰富的开发人员构建&#xff0c;解决了 Web 开发的大部分麻烦&#xff0c;因此您可以专注于编写应用程序&#xff0c;而无需重新发明轮子。它是免费且开源的。 项目地址…

Squid 缓存服务器

Squid 缓存服务器 作为应用层的代理服务软件&#xff0c;Squid 主要提供缓存加速和应用层过滤控制的功能 ☆什么是缓存代理 当客户机通过代理来请求 Web 页面时 指定的代理服务器会先检查自己的缓存&#xff0c;如果缓存中已经有客户机需要访问的页面&#xff0c;则直接将缓…

【考研思维题】【哈希表 || 什么时候用哈希表呢?快速查询的时候】【我们一起60天准备考研算法面试(大全)-第九天 9/60】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客&#xff0c;如有问题交流&#xff0c;欢迎评论区留言&#xff0c;一定尽快回复&#xff01;&#xff08;大家可以去看我的专栏&#xff0c;是所有文章的目录&#xff09;   文章字体风格&#xff1a; 红色文字表示&#…