【Pytorch】学习记录分享10——TextCNN用于文本分类处理

【Pytorch】学习记录分享10——PyTorchTextCNN用于文本分类处理

      • 1. TextCNN用于文本分类
      • 2. 代码实现

1. TextCNN用于文本分类

具体流程:
在这里插入图片描述
在这里插入图片描述

2. 代码实现

# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class Config(object):

    """配置参数"""
    def __init__(self, dataset, embedding):
        self.model_name = 'TextCNN'
        self.train_path = dataset + '/data/train.txt'                                # 训练集
        self.dev_path = dataset + '/data/dev.txt'                                    # 验证集
        self.test_path = dataset + '/data/test.txt'                                  # 测试集
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]                                # 类别名单
        self.vocab_path = dataset + '/data/vocab.pkl'                                # 词表
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'        # 模型训练结果
        self.log_path = dataset + '/log/' + self.model_name
        self.embedding_pretrained = torch.tensor(
            np.load(dataset + '/data/' + embedding)["embeddings"].astype('float32'))\
            if embedding != 'random' else None                                       # 预训练词向量
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # 设备

        self.dropout = 0.5                                              # 随机失活
        self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练
        self.num_classes = len(self.class_list)                         # 类别数
        self.n_vocab = 0                                                # 词表大小,在运行时赋值
        self.num_epochs = 20                                            # epoch数
        self.batch_size = 128                                           # mini-batch大小
        self.pad_size = 32                                              # 每句话处理成的长度(短填长切)
        self.learning_rate = 1e-3                                       # 学习率
        self.embed = self.embedding_pretrained.size(1)\
            if self.embedding_pretrained is not None else 300           # 字向量维度
        self.filter_sizes = (2, 3, 4)                                   # 卷积核尺寸
        self.num_filters = 256                                          # 卷积核数量(channels数)


'''Convolutional Neural Networks for Sentence Classification'''


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        if config.embedding_pretrained is not None:
            self.embedding = nn.Embedding.from_pretrained(config.embedding_pretrained, freeze=False)
        else:
            self.embedding = nn.Embedding(config.n_vocab, config.embed, padding_idx=config.n_vocab - 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.num_filters, (k, config.embed)) for k in config.filter_sizes])
        self.dropout = nn.Dropout(config.dropout)
        self.fc = nn.Linear(config.num_filters * len(config.filter_sizes), config.num_classes)

    def conv_and_pool(self, x, conv):
        x = F.relu(conv(x)).squeeze(3)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

    def forward(self, x):
        #print (x[0].shape)
        out = self.embedding(x[0])
        out = out.unsqueeze(1)
        out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)
        out = self.dropout(out)
        out = self.fc(out)
        return out

该代码对应上述的图像中的模块实现,CNN用于处理文本数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/296661.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

canal本地搭建以及运行

具体的文档可参考官网文档:https://github.com/alibaba/canal/wiki canal [kənl],译意为水道/管道/沟渠,主要用途是基于 MySQL 数据库增量日志解析,提供增量数据订阅和消费 工作原理 canal 模拟 MySQL slave 的交互协议&#x…

01-线程池项目背景:C++的数据库操作

从0开始学习C与数据库的联动 1.原始方式-使用MySQL Connector/C 提供的API查询 1.1 数据库预操作 我的本地电脑上有mysql数据库,里面预先创建了一个database名叫chat,用户名root,密码password。 1.2 Visual Studio预操作 在Windows上使用…

vue3中使用echarts:tooltip的trigger为axis tooltip不显示问题

vue3中使用echarts时,tooltip的trigger设置为axis时formatter不触发 tooltip: {trigger: "axis",formatter: function (params) {console.log("params", params);},axisPointer: {type: "shadow", // 阴影指示器}, },解决办法&#…

10分钟设置免费海外远程桌面使用Amazon Lightsail服务的免费套餐轻松搭建远程桌面

本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 亚马逊云科技开发者社区, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道。 目录 前言 使用教程 启动 Amazon Lightsail 实例 配置远程桌面 启动远程桌面 使…

圆通速递单号查询入口,筛选出指定某天签收的单号

随着电商和物流行业的飞速发展,快递单号的管理也成了一个让人头疼的问题。如何快速筛选、整理这些快递单号,成为了提高生活和工作效率的关键。而【快递批量查询高手】的出现,正好可以巧妙的解决上面的问题,下面就来具体看看这款软…

Docker一键极速安装Nacos,并配置数据库!

1 部署方式 1.1 DockerHub javaedgeJavaEdgedeMac-mini ~ % docker run --name nacos \ -e MODEstandalone \ -e JVM_XMS128m \ -e JVM_XMX128m \ -e JVM_XMN64m \ -e JVM_MS64m \ -e JVM_MMS64m \ -p 8848:8848 \ -d nacos/nacos-server:v2.2.3 a624c64a1a25ad2d15908a67316d…

【洛谷千题详解】P5706 【深基2.例8】再分肥宅水

只需要用t/n即可。 AC代码&#xff1a; #include<bits/stdc.h> using namespace std; int main() {float a;int b;cin>>a>>b;double ca/b;printf("%.3f\n",c);cout<<b*2<<endl;return 0; }

VuePress部署到GitHub Pages

一、git push自动部署 1、创建用于工作流的文件 在项目根目录下创建一个用于 GitHub Actions 的工作流 .yml 文件 name: docson:# 每当 push 到 main 分支时触发部署push:branches: [main]# 手动触发部署workflow_dispatch:jobs:docs:runs-on: ubuntu-lateststeps:- uses: a…

问题 E: 便利店

题目描述 天宝来到便利店想买些饮料。便利店有各种型号的瓶装饮料售卖&#xff0c;不同型号的饮料卖不同的价格。1瓶0.25升的卖A元&#xff0c;1瓶0.5升的饮料卖B元&#xff0c;1瓶1升的卖C元&#xff0c;1瓶2升的卖D元。便利店里每种饮料都是无限供应。 天宝要买N升的饮料&a…

逻辑斯蒂回归

逻辑斯蒂回归简介 逻辑斯蒂回归&#xff08;Logistic Regression&#xff09;是一个非常经典的算法&#xff0c;虽然被称为回归&#xff0c;但其实际上是分类模型&#xff0c;并常用于二分类。因为通过逻辑回归模型&#xff0c;我们得到的计算结果是0-1之间的连续数字&#xff…

FreeRTOS 实时操作系统第九讲 - 链表 (数据结构)

一、链表简述 链表是一种物理存储单元上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的指针链接次序实现的。链表由一系列节点&#xff08;链表中每一个元素称为节点&#xff09;组成&#xff0c;节点可以在运行时动态生成。每个节点包括两个部分&…

虾皮、Lazada店铺流量怎么提升?自养号优势及测评系统如何搭建?

虾皮、Lazada是东南亚地区最大的购物平台之一&#xff0c;吸引了大量的买家和卖家。在竞争激烈的虾皮市场上&#xff0c;如何提升店铺的流量成为许多卖家关注的问题。以下是关于如何提升虾皮、Lazada店铺流量的一些建议。 一、店铺流量怎么提升? 首先&#xff0c;进行优质的…

C#编程-使用集合

使用集合 您学习了如何使用数组来有效地存储和操作相似类型额数据。但是,以下限制于数组的使用相关联: 您必须在声明时定义数组的大小。您必须编写代码以对数组执行标准操作,如排序。让我们思考一个示例。假设您想要存储在组织工作的五个雇员的姓名。您可以使用以下语句来声…

java基于ssm的线上选课系统的设计与实现论文

摘 要 在如今社会上&#xff0c;关于信息上面的处理&#xff0c;没有任何一个企业或者个人会忽视&#xff0c;如何让信息急速传递&#xff0c;并且归档储存查询&#xff0c;采用之前的纸张记录模式已经不符合当前使用要求了。所以&#xff0c;对学生选课信息管理的提升&#x…

无人直播源码/技术源头开发/直播贴片技术源头

无人直播源码/技术源头开发/直播贴片&#xff0c;无人直播&#xff0c;无人实景直播。动态切片功能&#xff0c;自动回复功能&#xff0c;压低打断主播音&#xff0c;支持短视频各大主流平台开播。 搭建无人直播源码需要以下步骤&#xff1a; 1. 确定直播平台和工具&#xff1…

基于孔雀优化算法的航线规划

MATLAB2020a下正常运行 上传明细-CSDN创作中心

XSKY SDS 产品率先获得 OceanBase V4 新版本认证

近日&#xff0c;北京奥星贝斯科技有限公司&#xff08;简称&#xff1a;OceanBase&#xff09;与北京星辰天合科技股份有限公司&#xff08;简称&#xff1a;XSKY 星辰天合&#xff09;顺利完成产品兼容性认证。 XSKY 的高性能全闪存储以及混闪存储&#xff0c;与 OceanBase V…

大数据平台Bug Bash大扫除最佳实践

一、背景 随着越来越多的"新人"在日常工作以及大促备战中担当大任&#xff0c;我们发现仅了解自身系统业务已不能满足日常系统开发运维需求。为此&#xff0c;大数据平台部门组织了一次Bug Bash活动&#xff0c;既能提升自己对兄弟产品的理解和使用&#xff0c;又能…

【IPC通信--信号】

信号处理函数 • 信号发送函数 – kill(), sigqueue(), raise(), alarm(), setitimer(), pause() &#xff0c; abort() • 信号安装函数 – signal(), sigaction() • 信号集操作函数 – sigemptyset(), sigfillset(), sigaddset(), sigdelset(), sigismember() 信号发送函数—…

xgboost对密西西比数据集csv文件进行预测

代码&#xff1a; # 导入需要的库 from sklearn.preprocessing import LabelEncoder import matplotlib.pyplot as plt import pandas as pd import xgboost as xgb from sklearn.model_selection import train_test_split from sklearn.metrics import confusion_matrix, cla…