【NLP练习】Pytorch文本分类入门

Pytorch文本分类入门

🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍖 原作者:K同学啊 | 接辅导、项目定制

文本分类流程图

一、前期准备

1. 环境安装

确保已经安装torchtext与portalocker库

2. 加载数据

#加载数据
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")   #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

输出:

device(type='cpu')
from torchtext.datasets import AG_NEWS

train_iter = AG_NEWS(split='train')   #加载AG NEWS数据集

3. 构建词典

#构建词典
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')   #返回分词器函数

def yield_tokens(data_iter):
    for _,text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter),
                                 specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])   #设置默认索引,如果找不到单词,则会选择默认索引

vocab(['here','is','an','example'])

输出:

[475, 21, 30, 5297]
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
text_pipeline('here is the an example')

输出:


[475, 21, 2, 30, 5297]

4. 生成数据批次和迭代器

# 生成数据批次和迭代器
from torch.utils.data import DataLoader

def collate_batch(batch):
    label_list, text_list, offsets = [],[],[0]
    for(_label, _text) in batch:
        #标签列表
        label_list.append(label_pipeline(_label))
        #文本列表
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        #偏移量
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list,dtype=torch.int64)
    text_list = torch.cat(text_list)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)       #返回维度dim中输入元素的累计和
    return label_list.to(device), text_list.to(device), offsets.to(device)

#数据加载器
dataloader = DataLoader(
    train_iter,
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_batch
)

二、准备模型

1. 定义模型

定义TextClassificationModel模型,首先对文本进行嵌入,然后对句子嵌入后的结果进行均值聚合
在这里插入图片描述

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel,self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size,      #词典大小
                                        embed_dim,        # 嵌入的维度
                                        sparse=False)     #
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

2. 定义实例

#定义实例
num_class = len(set([label for (label,text) in train_iter]))
vocab_size = len(vocab)
em_size = 64
model = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练函数与评估函数

#定义训练函数与评估函数
import time

def train(dataloader):
    model.train()          #切换为训练模式
    total_acc, train_loss, total_count = 0,0,0
    log_interval = 500
    start_time = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()                             #grad属性归零
        loss = criterion(predicted_label, label)          #计算网络输出和真实值之间的差距,label为真
        loss.backward()                                   #反向传播
        optimizer.step()                                  #每一步自动更新
        #记录acc与loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('|epoch{:d}|{:4d}/{:4d} batches|train_acc{:4.3f} train_loss{:4.5f}'.format(
                epoch,
                idx,
                len(dataloader),
                total_acc/total_count,
                train_loss/total_count))
            total_acc,train_loss,total_count = 0,0,0
            staet_time = time.time()

def evaluate(dataloader):
    model.eval()      #切换为测试模式
    total_acc,train_loss,total_count = 0,0,0
    with torch.no_grad():
        for idx,(label,text,offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label,label)   #计算loss值
            #记录测试数据
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            train_loss += loss.item()
            total_count += label.size(0)
    
    return total_acc/total_count, train_loss/total_count

三、训练模型

1. 拆分数据集并运行模型

from torch.utils.data.dataset   import random_split
from torchtext.data.functional  import to_map_style_dataset

# 超参数设定
EPOCHS      = 10
LR          = 5
BATCH_SIZE  = 64

#设置损失函数、选择优化器、设置学习率调整函数
criterion   = torch.nn.CrossEntropyLoss()
optimizer   = torch.optim.SGD(model.parameters(), lr = LR)
scheduler   = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma = 0.1)
total_accu  = None

# 加载数据
train_iter, test_iter   = AG_NEWS() 
train_dataset   = to_map_style_dataset(train_iter)
test_dataset    = to_map_style_dataset(test_iter)
num_train       = int(len(train_dataset) * 0.95)

split_train_ , split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
                                           
train_dataloader    = DataLoader(split_train_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
valid_dataloader    = DataLoader(split_valid_, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)
test_dataloader     = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = True, collate_fn = collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s | valid_acc {:4.3f} valid_loss {:4.3f}'.format(
        epoch,
        time.time() - epoch_start_time,
        val_acc,
        val_loss))
    print('-' * 69)

输出:

|epoch1| 500/1782 batches|train_acc0.909 train_loss0.00420
|epoch1|1000/1782 batches|train_acc0.909 train_loss0.00431
|epoch1|1500/1782 batches|train_acc0.910 train_loss0.00415
---------------------------------------------------------------------
| epoch 1 | time:17.55s | valid_acc 0.913 valid_loss 0.004
---------------------------------------------------------------------
|epoch2| 500/1782 batches|train_acc0.924 train_loss0.00355
|epoch2|1000/1782 batches|train_acc0.922 train_loss0.00366
|epoch2|1500/1782 batches|train_acc0.917 train_loss0.00376
---------------------------------------------------------------------
| epoch 2 | time:17.58s | valid_acc 0.914 valid_loss 0.004
---------------------------------------------------------------------
|epoch3| 500/1782 batches|train_acc0.929 train_loss0.00329
|epoch3|1000/1782 batches|train_acc0.929 train_loss0.00332
|epoch3|1500/1782 batches|train_acc0.929 train_loss0.00337
---------------------------------------------------------------------
| epoch 3 | time:19.67s | valid_acc 0.892 valid_loss 0.005
---------------------------------------------------------------------
|epoch4| 500/1782 batches|train_acc0.947 train_loss0.00258
|epoch4|1000/1782 batches|train_acc0.946 train_loss0.00257
|epoch4|1500/1782 batches|train_acc0.946 train_loss0.00266
---------------------------------------------------------------------
| epoch 4 | time:18.36s | valid_acc 0.915 valid_loss 0.004
---------------------------------------------------------------------
|epoch5| 500/1782 batches|train_acc0.951 train_loss0.00243
|epoch5|1000/1782 batches|train_acc0.949 train_loss0.00252
|epoch5|1500/1782 batches|train_acc0.947 train_loss0.00256
---------------------------------------------------------------------
| epoch 5 | time:17.92s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch6| 500/1782 batches|train_acc0.950 train_loss0.00245
|epoch6|1000/1782 batches|train_acc0.950 train_loss0.00246
|epoch6|1500/1782 batches|train_acc0.950 train_loss0.00245
---------------------------------------------------------------------
| epoch 6 | time:18.10s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch7| 500/1782 batches|train_acc0.950 train_loss0.00245
|epoch7|1000/1782 batches|train_acc0.951 train_loss0.00242
|epoch7|1500/1782 batches|train_acc0.951 train_loss0.00239
---------------------------------------------------------------------
| epoch 7 | time:18.08s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
|epoch8| 500/1782 batches|train_acc0.951 train_loss0.00238
|epoch8|1000/1782 batches|train_acc0.951 train_loss0.00241
|epoch8|1500/1782 batches|train_acc0.955 train_loss0.00228
---------------------------------------------------------------------
| epoch 8 | time:18.75s | valid_acc 0.918 valid_loss 0.004
---------------------------------------------------------------------
|epoch9| 500/1782 batches|train_acc0.952 train_loss0.00234
|epoch9|1000/1782 batches|train_acc0.953 train_loss0.00235
|epoch9|1500/1782 batches|train_acc0.951 train_loss0.00237
---------------------------------------------------------------------
| epoch 9 | time:18.50s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------
|epoch10| 500/1782 batches|train_acc0.951 train_loss0.00234
|epoch10|1000/1782 batches|train_acc0.954 train_loss0.00231
|epoch10|1500/1782 batches|train_acc0.954 train_loss0.00234
---------------------------------------------------------------------
| epoch 10 | time:17.82s | valid_acc 0.917 valid_loss 0.004
---------------------------------------------------------------------

2. 使用测试数据集评估模型

print('Checking the results of test dataset.')
test_acc,test_loss = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(test_acc))

输出:

Checking the results of test dataset.
test accuracy    0.908

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/504550.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【滑动窗口】Leetcode 找到字符串中所有字母异位词

题目解析 438. 找到字符串中所有字母异位词 算法讲解 寻找目标串的异位词&#xff0c;我们使用固定长度的滑动窗口&#xff0c;首先我们判断窗口左右的字符是否存在于目标串中&#xff0c;如果不存在就让窗口滑动&#xff1b;存在的话&#xff0c;我们就把字符丢进Hash中&a…

【JavaSE】类和对象详解(上)

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 类和对象 类的组成 对类的理解 成员变量的访问和类方法的调用 this 抛出一个问题 this的作用 初始化成员变量 未初始化的成员变量 代码举例 就地初始化 构…

每日一题(leetcode331):验证二叉树的前序序列化——栈

类似消消乐&#xff0c;数字&#xff0b;“#”“#”就可以消成一个“#”&#xff0c;到最后如果栈中只剩一个“#”便说明序列正确。 当然也可以用槽位理解&#xff0c;一个数字出现会消耗一个槽位产生两个槽位&#xff08;即产生一个槽位&#xff09;&#xff0c;一个“#”出现…

题目:安全序列(蓝桥OJ 3432)

问题描述&#xff1a; 题解&#xff1a; #include <bits/stdc.h> using namespace std; using ll long long; const int N 1e6 9, p 1e9 7;int prefix[N],dp[N];int main() {int n, k;cin >> n >> k;dp[0] prefix[0] 1;for(int i 1; i < n; i){i…

iptables 与 firewalld 防火墙

iptables iptables 是一款基于命令行的防火墙策略管理工具 四种防火墙策略&#xff1a; ACCEPT&#xff08;允许流量通过&#xff09; 流量发送方会看到响应超时的提醒&#xff0c;但是流量发送方无法判断流量是被拒绝&#xff0c;还是接收方主机当前不在线 REJECT&#xff08…

STM32的IAP技术,BootLoader

来源 三种下载方式&#xff1a; 1、ICP&#xff1a;ST-Link, 2、ISP: FlyMcu, 3、IAP IAP简介 IAP技术的核心在于BootLoader程序的设计&#xff0c;这段程序预先烧录在单片机中&#xff0c;正常的APP程序可以使用BootLoader程序中的IAP功能写入&#xff0c;也可以两部分代码一…

【JavaEE初阶系列】——常见的锁策略

目录 &#x1f6a9;乐观锁和悲观锁 &#x1f6a9;读写锁和普通互斥锁 &#x1f6a9;轻量级锁和重量级锁 &#x1f6a9;自旋锁和挂起等待锁 &#x1f6a9;公平锁和非公平锁 &#x1f6a9;可重入锁和不可重入锁 &#x1f6a9;关于synchronized的锁策略以及自适应 接下来讲解的锁策…

Springboot之RESTful风格

目录 1、概述&#xff1a; 1.1、传统风格的API&#xff1a; 1.2、RESTful风格的API&#xff1a; 1.3、GET、POST、PUT、DELETE&#xff1a; 2、RESTful风格相关的注解&#xff1a; ①PathVariable&#xff0c;用来获取url中的数据&#xff1b; ②GetMapping&#xff0c;接…

C++教学——从入门到精通 6.ASCII码与字符型

如何把小写字母转换成大写字母呢&#xff1f; 这个问题问的好&#xff0c;首先我们要新学一个类型——char 这个类型就是字符型 再来说说ASCII码 给大家举几个例子 空格————32 0————48 9————57 A————65 Z————90 a————97 z————122 我们…

Cortex-M7中断向量表的重定向

1 前言 系统上电后&#xff0c;PC会指向复位向量&#xff0c;即向量表中的Reset_Handler&#xff0c;而系统就是通过Vector Table Offset Register (VTOR)的值加上4字节来找到复位向量的入口的。 因为地址 0 处应该存储引导代码&#xff0c;所以它通常映射到 Flash 或者是 ROM …

Unity类银河恶魔城学习记录11-14 p116 Thunder strike item effect源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili ThurderStrike_Controller.cs using System.Collections; using System.C…

如何提高小红书笔记的收录率?

在小红书平台上&#xff0c;笔记的收录率是衡量一篇笔记是否受欢迎和有价值的重要因素。为了提高笔记的收录率&#xff0c;有几个关键点需要注意&#xff1a; 1.内容不涉及广告 在发布笔记前要先确保笔记内容不包含任何形式的广告或推广信息。小红书平台对于广告性质的内容有…

关于Ansible的模块②

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 接《关于Ansible的模块 ①-CSDN博客》&#xff0c;继续学习和梳理Ansible的常用文件类模块 1. copy模块 从当前机器上复制文件到…

关于 Unreal 的各种坐标系、输入与逻辑的转换问题

说明 已知&#xff1a; 在世界原点往 X 轴方向看去&#xff0c;ForwardVector 为 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]&#xff0c;此时的右手边的方向为 [ 0 , 1 , 0 ] [0,1,0] [0,1,0] 手柄摇杆、鼠标移动朝右得到的 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]&#xff0c;朝上得到的是…

代码随想录:字符串5-7

右旋字符串 题目 字符串的右旋转操作是把字符串尾部的若干个字符转移到字符串的前面。给定一个字符串 s 和一个正整数 k&#xff0c;请编写一个函数&#xff0c;将字符串中的后面 k 个字符移到字符串的前面&#xff0c;实现字符串的右旋转操作。 例如&#xff0c;对于输入字符…

货币与利率

货币与利率 货币及其职能什么是货币货币的职能货币带来了什么&#xff1f; 货币形式的演变商品货币代用货币信用货币货币的特性 现代社会货币的表现形式流通中的现金支票存款信用卡储存存款 货币层次划分目的划分标准划分种类我国的货币层次 货币与物价的关系利率什么是利息什么…

算法学习——LeetCode力扣补充篇1

算法学习——LeetCode力扣补充篇1 1365. 有多少小于当前数字的数字 1365. 有多少小于当前数字的数字 - 力扣&#xff08;LeetCode&#xff09; 描述 给你一个数组 nums&#xff0c;对于其中每个元素 nums[i]&#xff0c;请你统计数组中比它小的所有数字的数目。 换而言之&a…

推荐一本牛逼的入门 Python书!,如何试出一个Python开发者真正的水平

本书详细解说了 Python 语言和编程的本质&#xff0c;无论你是否接触过编程语言&#xff0c;只要是 Python 编程的初学者&#xff0c;都可阅读本书。 本书讲解的内容虽然基础&#xff0c;但并不简单。本书提供了 165 幅图表&#xff0c;可以让大家能够轻松地理解并掌握复杂的概…

Taro+vue3 监听当前的页面滚动的距离

1.需求 想实现一个这样的效果 一开始这个城市组件 是透明的 在顶部 的固定定位 当屏幕滑动的时候到一定的距离 将这个固定的盒子 背景颜色变成白色 2.Taro中的滚动 Taro中的滚动 有固定的api 像生命周期一样 这个生命周期是 usePageScroll import Taro, { useDidShow, useP…

外包干了5天,技术退步明显.......

先说一下自己的情况&#xff0c;大专生&#xff0c;18年通过校招进入杭州某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落! 而我已经在一个企业干了四年的功能测…