【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务实战

【Pytorch】学习记录分享9——PyTorch新闻数据集文本分类任务

      • 1. 认为主流程code
      • 2. NLP 对话和预测基本均属于分类任务详细见
      • 3. Tensorborad

1. 认为主流程code

import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
from tensorboardX import SummaryWriter

###制定参数 --model TextRNN
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
args = parser.parse_args()


if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  #TextCNN, TextRNN,
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter,writer)

RNN


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.lstm = nn.LSTM(config.embed, config.hidden_size, config.num_layers,
                            bidirectional=True, batch_first=True, dropout=config.dropout)
        self.fc = nn.Linear(config.hidden_size * 2, config.num_classes)

    def forward(self, x):
        x, _ = x
        out = self.embedding(x)  # [batch_size, seq_len, embeding]=[128, 32, 300]
        out, _ = self.lstm(out)
        out = self.fc(out[:, -1, :])  # 句子最后时刻的 hidden state
        return out

在这里插入图片描述
TextRNN h_t 为RNN提取出来的特征

2. NLP 对话和预测基本均属于分类任务详细见

Pytorch学习记录分享9-PyTorch新闻数据集文本分类任务实战

3. Tensorborad

数据可视化操作 code repo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/274813.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序

如果您有 Android 设备,您可能会将个人和专业的重要文件保存在设备的 SD 卡上。这些文件包括照片、视频、文档和各种其他类型的文件。您绝对不想丢失这些文件,但当您的 SD 卡损坏时,数据丢失是不可避免的。 幸运的是,您不需要这样…

Redis的缓存过期淘汰策略

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理、数据库技术🔥如果感觉博主的文章还不错的…

docker学习(十八、network介绍)

[TOC]添加链接描述 首先,我们要知道什么是 Docker 网络。简单来说,它就是 Docker 中用于实现容器间通信的一个东西。 network相关内容: docker学习(十八、network介绍) docker学习(十九、network使用示例br…

COSCon'23 主论坛回顾:基金会的治理模式

在开源软件和开源社区中,开源基金会扮演着至关重要的角色,为开源项目和社区提供了一种结构化和有组织的支持,有助于确保开源项目的成功、可持续性和广泛采用。她们充当了协调者、中介和支持者的角色,有助于促进开源技术的发展和推…

面向对象(高级)知识点强势总结!!!

文章目录 一、知识点复习1-关键字:static1、知识点2、重点 2-单例模式(或单子模式)1、知识点2、重点 3-理解main()方法1、知识点2、重点 4-类的成员之四:代码块1、知识点2、重点 5-关键字:final1、知识点2、重点 6-关键…

【Unity自制手册】基于Unity中物体移动相关方法和API集锦(动图详解)

👨‍💻个人主页:元宇宙-秩沅 👨‍💻 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍💻 本文由 秩沅 原创 👨‍💻 收录于专栏:uni…

数字时代跨境电商营销大变革:海外网红营销的力量与影响

随着全球化的推进和数字技术的不断发展,跨境电商行业迎来了一场营销变革的浪潮。在这个过程中,一种新的营销方式崭露头角,那就是海外网红营销。海外网红以其独特的个人魅力和影响力,成为跨境电商推广的重要力量,为品牌…

2023年下半年软件设计师上午真题及答案解析

1.在双核处理器中,双核是指( )。 A.执行程序时有两条指令流水线并行工作 B.在一个CPU中集成两个运算核心以提高运算能力 C.利用超线程技术实现的多任务并行处理 D.在主板上设置两个独立的 CPU 以提高处理能力 2.某文件管理系统在磁盘上建立了位示图(bitmap)&am…

JOSEF约瑟 温度继电器 JUC-1M (≥20℃断开)常开型

JUC系列温度继电器 JUC-1M型超小型密封温度继电器 JUC-2M型超小型密封温度继电器 继电器JUC-027M/2531H-III-G温度继电器 JUC-1M 10C常开温度继电器 JUC-1M 105C温度继电器 用途 超小型温控开关系接触感应式密封温度继电器,具有体积小、重量轻、控温精度高等特点&…

Ubuntu中fdisk磁盘分区并挂载、扩容逻辑卷

Ubuntu中fdisk磁盘分区并挂载、扩容逻辑卷 一:fdisk磁盘分区并挂载1.查看磁盘分区信息2.分区3.强制系统重新读取分区(避免重启系统)4.格式化分区5.创建挂载目录6.设置开机自动挂载:7.验证并自动挂载(执行了该命令不需要重启系统)8.查看挂载007.异常情况处…

【机器学习】西瓜书第6章支持向量机课后习题6.1参考答案

【机器学习】西瓜书学习心得及课后习题参考答案—第6章支持向量机 1.试证明样本空间中任意点x到超平面(w,b)的距离为式(6.2)。 首先,直观解释二维空间内点到直线的距离: 由平面向量的有关知识,可得: 超平面的法向量为 w w w&am…

【Spring实战】09 MyBatis Generator

文章目录 1. 依赖2. 配置文件3. 生成代码4. 详细介绍 generatorConfig.xml5. 代码详细总结 Spring MyBatis Generator 是 MyBatis 官方提供的一个强大的工具,它能够基于数据库表结构自动生成 MyBatis 持久层的代码,包括实体类、Mapper 接口和 XML 映射文…

【i阿极送书——第六期】《YOLO目标检测》

系列文章目录 作者:i阿极 作者简介:数据分析领域优质创作者、多项比赛获奖者:博主个人首页 😊😊😊如果觉得文章不错或能帮助到你学习,可以点赞👍收藏📁评论📒…

基于飞浆OCR的文本框box及坐标中心点检测JSON格式保存文本

OCR的文本框box及JSON数据保存 需求说明 一、借助飞浆框出OCR识别的文本框 二、以圆圈形式标出每个框的中心点位置 三、以JSON及文本格式保存OCR识别的文本 四、以文本格式保存必要的文本信息 解决方法 一、文本的坐标来自飞浆的COR识别 二、借助paddleocr的draw_ocr画出…

听说!Art-DAQ实现了与LabVIEW的无缝连接

前言 阿尔泰科技与时俱进,推出Art-DAQ程序,与LabVIEW无缝连接,形成系统平台体系。持续不断地获取行业新技术,完善自主知识产权产品的研发,为客户提供优质服务。 什么是Labview? 从产品的角度来看&#x…

【信息安全原理】——入侵检测与网络欺骗(学习笔记)

📖 前言:在网络安全防护领域,防火墙是保护网络安全的一种最常用的设备。网络管理员希望通过在网络边界合理使用防火墙,屏蔽源于外网的各类网络攻击。但是,防火墙由于自身的种种限制,并不能阻止所有攻击行为…

原生微信小程序如何动态配置主题颜色及如何调用子组件的方法

一、最终效果 二、步骤 1、在初始化进入项目时,获取当前主题色 2、把主题色定义成全局变量(即在app.js中设置) 3、tabBar也需要定义全局变量,在首页时需要重新赋值 三、具体实现 1、app.js onLaunch () {//获取主题数据this.set…

SkyWalking UI 修改发布Nginx

文章目录 SkyWalking UI修改图标修改路由发布到Nginx添加认证修改路由模式vite.config.ts添加baseNginx配置 SkyWalking UI skywalking-booster-ui下载地址 修改图标 替换 logo.svg 修改路由 router - data - index.ts 发布到Nginx 添加认证 # 安装 yum install -y h…

Ubuntu安装K8S的dashboard(管理页面)

原文网址:Ubuntu安装k8s的dashboard(管理页面)-CSDN博客 简介 本文介绍Ubuntu安装k8s的dashboard(管理页面)的方法。 Dashboard的作用有:便捷操作、监控、分析、概览。 相关网址 官网地址:…