第N6周:使用Word2vec实现文本分类

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings
#忽略警告信息
warnings.filterwarnings("ignore")
# win10系统
device = torch.device("cuda"if torch.cuda.is_available()else"cpu")
device

import pandas as pd
# 加载自定义中文数据
train_data= pd.read_csv('./data/train2.csv',sep='\t',header=None)
train_data.head()

# 构造数据集迭代器
def coustom_data_iter(texts,labels):
    for x,y in zip(texts,labels):
        yield x,y
x = train_data[0].values[:]
#多类标签的one-hot展开
y = train_data[1].values[:]


from gensim.models.word2vec import Word2Vec
import numpy as np
#训练word2Vec浅层神经网络模型
w2v=Word2Vec(vector_size=100#是指特征向量的维度,默认为100。
              ,min_count=3)#可以对字典做截断。词频少于min_count次数的单词会被丢弃掉,默认为5

w2v.build_vocab(x)
w2v.train(x,total_examples=w2v.corpus_count,epochs=20)


# 将文本转化为向量
def average_vec(text):
    vec =np.zeros(100).reshape((1,100))
    for word in text:
        try:
            vec +=w2v.wv[word].reshape((1,100))
        except KeyError:
                continue
    return vec
#将词向量保存为Ndarray
x_vec= np.concatenate([average_vec(z)for z in x])
#保存Word2Vec模型及词向量
w2v.save('data/w2v_model.pk1')


train_iter= coustom_data_iter(x_vec,y)
len(x),len(x_vec)

label_name =list(set(train_data[1].values[:]))
print(label_name)


text_pipeline =lambda x:average_vec(x)
label_pipeline =lambda x:label_name.index(x)

text_pipeline("你在干嘛")
label_pipeline("Travel-Query")


from torch.utils.data import DataLoader
def collate_batch(batch):
    label_list,text_list=[],[]
    for(_text,_label)in batch:
        # 标签列表
        label_list.append(label_pipeline(_label))
        # 文本列表
        processed_text = torch.tensor(text_pipeline(_text),dtype=torch.float32)
        text_list.append(processed_text)
        label_list = torch.tensor(label_list,dtype=torch.int64)
        text_list = torch.cat(text_list)
    return text_list.to(device),label_list.to(device)
# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,
shuffle =False,
collate_fn=collate_batch)



from torch import nn
class TextclassificationModel(nn.Module):

    def __init__(self,num_class):
        super(TextclassificationModel,self).__init__()
        self.fc = nn.Linear(100,num_class)
    def forward(self,text):
        return self.fc(text)


num_class =len(label_name)
vocab_size =100000
em_size=12
model= TextclassificationModel(num_class).to(device)




import time
def train(dataloader):
    model.train()#切换为训练模式
    total_acc,train_loss,total_count =0,0,0
    log_interval=50
    start_time= time.time()

    for idx,(text,label)in enumerate(dataloader):
        predicted_label= model(text)
        # grad属性归零
        optimizer.zero_grad()
        loss=criterion(predicted_label,label)#计算网络输出和真实值之间的差距,label
        loss.backward()
        #反向传播
        torch.nn.utils.clip_grad_norm(model.parameters(),0.1)#梯度裁剪
        optimizer.step()#每一步自动更新
        #记录acc与loss
        total_acc+=(predicted_label.argmax(1)==label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval==0 and idx>0:
            elapsed =time.time()-start_time
            print('Iepoch {:1d}I{:4d}/{:4d} batches'
            '|train_acc {:4.3f} train_loss {:4.5f}'.format(epoch,idx,len(dataloader),total_acc/total_count,train_loss/total_count))
            total_acc,train_loss,total_count =0,0,0

            start_time = time.time()
def evaluate(dataloader):
    model.eval()#切换为测试模式
    total_acc,train_loss,total_count =0,0,0
    with torch.no_grad():
        for idx,(text,label)in enumerate(dataloader):
            predicted_label= model(text)
            loss = criterion(predicted_label,label)# 计算loss值
            # 记录测试数据
            total_acc+=(predicted_label.argmax(1)== label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    return total_acc/total_count,train_loss/total_count




from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS=10#epoch
LR=5 #学习率
BATCH_SIZE=64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer= torch.optim.SGD(model.parameters(),lr=LR)
scheduler=torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.1)
total_accu = None
# 构建数据集
train_iter= coustom_data_iter(train_data[0].values[:],train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

split_train_,split_valid_= random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])
train_dataloader =DataLoader(split_train_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_,batch_size=BATCH_SIZE,
shuffle=True,collate_fn=collate_batch)
for epoch in range(1,EPOCHS+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc,val_loss = evaluate(valid_dataloader)
    # 获取当前的学习率
    lr =optimizer.state_dict()['param_groups'][0]['1r']
    if total_accu is not None and total_accu>val_acc:
      scheduler.step()
    else:
     total_accu = val_acc
    print('-'*69)
    print('|epoch {:1d}|time:{:4.2f}s |'
    'valid_acc {:4.3f} valid_loss {:4.3f}I1r {:4.6f}'.format(epoch,
    time.time()-epoch_start_time,
    val_acc,val_loss,lr))

    print('-'*69)


# test_acc,test_loss =evaluate(valid_dataloader)
# print('模型准确率为:{:5.4f}'.format(test_acc))
#
#
# def predict(text,text_pipeline):
#     with torch.no_grad():
#         text = torch.tensor(text_pipeline(text),dtype=torch.float32)
#         print(text.shape)
#         output = model(text)
#         return output.argmax(1).item()
# # ex_text_str="随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str="还有双鸭山到淮阴的汽车票吗13号的"
# model=model.to("cpu")
# print("该文本的类别是:%s"%label_name[predict(ex_text_str,text_pipeline)])

以上是文本识别基本代码

输出:

[[-0.85472693  0.96605204  1.5058695  -0.06065784 -2.10079319 -0.12021151
   1.41170089  2.00004494  0.90861696 -0.62710127 -0.62408304 -3.80595499
   1.02797993 -0.45584389  0.54715634  1.70490362  2.33389823 -1.99607518
   4.34822938 -0.76296186  2.73265275 -1.15046433  0.82106878 -0.32701646
  -0.50515595 -0.37742117 -2.02331601 -1.365334    1.48786476 -1.6394971
   1.59438308  2.23569647 -0.00500725 -0.65070192  0.07377997  0.01777986
  -1.35580809  3.82080549 -2.19764423  1.06595343  0.99296588  0.58972518
  -0.33535255  2.15471306 -0.52244038  1.00874437  1.28869729 -0.72208139
  -2.81094289  2.2614549   0.20799019 -2.36187895 -0.94019454  0.49448857
  -0.68613767 -0.79071895  0.47535057 -0.78339124 -0.71336574 -0.27931567
   1.0514895  -1.76352624  1.93158554 -0.85853558 -0.65540617  1.3612217
  -1.39405773  1.18187538  1.31730198 -0.02322496  0.14652854  0.22249881
   2.01789951 -0.40144247 -0.39880068 -0.16220299 -2.85221207 -0.27722868
   2.48236791 -0.51239379 -1.47679498 -0.28452797 -2.64497767  2.12093259
  -1.2326943  -1.89571355  2.3295732  -0.53244872 -0.67313893 -0.80814604
   0.86987564 -1.31373079  1.33797717  1.02223087  0.5817025  -0.83535647
   0.97088164  2.09045361 -2.57758138  0.07126901]]
6

输出结果并非为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/500955.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[flink 实时流基础]源算子和转换算子

文章目录 1. 源算子 Source1. 从集合读2. 从文件读取3. 从 socket 读取4. 从 kafka 读取5. 从数据生成器读取数据 2. 转换算子基本转换算子(map/ filter/ flatMap) 1. 源算子 Source Flink可以从各种来源获取数据,然后构建DataStream进行转换…

hcia datacom课程学习(5):MAC地址与arp协议

1.MAC地址 1.1 含义与作用 (1)含义: mac地址也称物理地址,是网卡设备在数据链路层的地址,全世界每一块网卡的mac地址都是唯一的,出厂时烧录在网卡上不可更改 (2)作用&#xff1a…

OKCC的API资源管理平台怎么用?

API资源管理平台,重点是“资源”管理平台,不是API接口管理平台。 天天讯通推出的API资源管理平台,类似昆石的VOS系统,区别是VOS是SIP资源管理系统,我们的API资源管理平台是API资源管理系统(AXB、AX、回拨AP…

科技下乡:数字乡村改变乡村生活方式

在科技飞速发展的时代,数字化、信息化浪潮正以前所未有的速度席卷全球。在这场科技革命中,乡村不再是滞后的代名词,而是成为了数字乡村建设的热土。科技下乡,让数字乡村成为了改变乡村生活方式的重要力量。 一、科技下乡&#xf…

京东云8核16G服务器配置租用优惠价格1198元1年、4688元三年

京东云轻量云主机8核16G服务器租用优惠价格1198元1年、4688元三年,配置为8C16G-270G SSD系统盘-5M带宽-500G月流量,华北-北京地域。京东云8核16G服务器活动页面 yunfuwuqiba.com/go/jd 活动链接打开如下图: 京东云8核16G服务器优惠价格 京东云…

操作系统OS Chapter1

操作系统OS 一、概念和功能1.概念2.功能3.目标 二、特征1.并发2.共享3.虚拟4.异步 三、发展四、运行机制五、中断和异常1.中断的作用2.中断的类型3.中断机制的原理 六、系统调用七、操作系统结构八、操作系统引导九、虚拟机 一、概念和功能 1.概念 操作系统(OS&…

harbor api v2.0

harbor api v2.0 v2.0 v2.0 “harbor api v2.0”与原来区别较大,此处harbor也做了https。另外,通过接口拿到的数据也是只能默认1页10个,所以脚本根据实际情况一页页的抓取数据 脚本主要用于统计repo、image,以及所有镜像的tag数&…

HTML网站的概念

目录 前言: 1.什么是网页: 2.什么是网站: 示例: 3.服务器: 总结: 前言: HTML也称Hyper Text Markup Language,意思是超文本标记语言,同时HTML也是前端的基础&…

IF= 13.4| 当eDNA遇上机器学习法

近日,凌恩生物客户重庆医科大学在《Water Research》(IF 13.4)发表研究论文“Supervised machine learning improves general applicability of eDNA metabarcoding for reservoir health monitoring”。该研究主要介绍了一种基于eDNA的机器学…

mysql的主从配置

MySQL主从复制是一种常见的数据库复制技术,用于实现数据在一个主数据库服务器和一个或多个从数据库服务器之间的同步。在主从配置中,主服务器负责接收和处理写操作,然后将这些变更通过binlog日志传播到从服务器,从服务器根据主服务…

【MySQL】7.MHA高可用配置及故障切换

什么是MHA MHA(MasterHigh Availability)是一套优秀的MySQL高可用环境下故障切换和主从复制的软件 mha用于解决mysql的单点故障问题; 出现故障时,mha能在0~30秒内自动完成故障切换; 并且能在故障切换过程中&#xff0…

《让你的时间多一倍》逃离时间陷阱,你没有自己想的那么懒 - 三余书屋 3ysw.net

让你的时间多一倍 今天我们来阅读法比安奥利卡尔的作品《让你的时间多一倍》。或许你会心生疑虑,这本书是否又是一本沉闷的时间管理指南?但我要告诉你的是,尽管时间管理这个话题已经为大众所熟知,这本书却为我们揭示了一个全新的…

【Roadmap to learn LLM】Large Language Models in Five Formulas

by Alexander Rush Our hope: reasoning about LLMs Our Issue 文章目录 Perpexity(Generation)Attention(Memory)GEMM(Efficiency)用矩阵乘法说明GPU的工作原理 Chinchilla(Scaling)RASP(Reasoning)结论参考资料 the five formulas perpexity —— generationattention —— m…

PyCharm中配置PyQt5并添加外部工具

Qt Designer、PyUIC和PyRcc是Qt框架下的三个重要工具,总的来说,这三个工具各司其职,相辅相成,能显著提升Qt开发的速度与效率。 Qt Designer:是一个用于创建图形用户界面的工具,可轻松构建复杂的用户界面。…

matlab及其在数字信号处理中的应用001:软件下载及安装

目录 一,matlab的概述 matlab是什么 matlab适用于的问题 matlab的易扩展性 二,matlab的安装 1,解压所有压缩文件 2,解压镜像压缩文件 3,运行setup.exe 4,开始安装 5,不要运行软件…

EasyBoss ERP上线实时数据大屏,Shopee本土店铺数据实时监测

近日,灵隐寺PPT汇报用上数据大屏疯狂刷屏,有做东南亚本土电商的老板发现这种数据大屏的模式可以很好地展现店铺运营状况。 所以就有老板来问:EasyBoss能不能也上线实时数据大屏的功能?没问题!立马安排! 要有…

BasicVSR++模型转JIT并用c++libtorch推理

BasicVSR模型转JIT并用clibtorch推理 文章目录 BasicVSR模型转JIT并用clibtorch推理安装BasicVSR 环境1.下载源码2. 新建一个conda环境3. 安装pytorch4. 安装 mim 和 mmcv-full5. 安装 mmedit6. 下载模型文件7. 测试一下能否正常运行 转换为JIT模型用c libtorch推理效果 安装Ba…

只出现一次的数字 II

题目链接 只出现一次的数字 II 题目描述 注意点 nums中,除某个元素仅出现一次外,其余每个元素都恰出现三次设计并实现线性时间复杂度的算法且使用常数级空间来解决此问题 解答思路 本题与只出现一次的数字的数字类似,区别是重复的数字会…

深度学习InputStreamReader类

咦咦咦,各位小可爱,我是你们的好伙伴——bug菌,今天又来给大家普及Java SE相关知识点了,别躲起来啊,听我讲干货还不快点赞,赞多了我就有动力讲得更嗨啦!所以呀,养成先点赞后阅读的好…

SpringMVC注解及使用规则

文章目录 前言一、SpringMVC注解是什么?二、使用步骤1.注解使用2创建JSP3 SpringMVC视图1. 逻辑视图(Logical View)2. 物理视图(Physical View)区别和关系 4 SpringMVC注解总结 总结 前言 提示:这里可以添…