pytorch集智4-情绪分类器

1 目标

从中文文本中识别出句子里的情绪。和上一章节单车预测回归问题相比,这个问题是分类问题,不是回归问题

2 神经网络分类器

2.1 如何用神经网络分类

第二章节用torch.nn.Sequantial做的回归预测器,输出神经元只有一个。分类器和其区别如下

1 分类器输出单元有几个,就有几个分类

2 分类器各输出单元输出值加和为1(值表示某个预测分类的概率,概率求和为1)

3 回归预测最后一层可以用sigmoid,分类不行,因为sigmoid无法保证分类器输出单元值求和为1,可以用softmax,表达式如下

2.2 分类问题损失函数

分类问题损失函数一般用交叉熵:

N表示样本数量,Ci表示第i个样本的类别标签,y表示每个样本的损失值

3 词袋模型分类器

直白理解:就是统计句子里各词语出现数量,忽略语法,语义等因素

思想:1 将句子转化为词袋,然后可以onehot处理,方便用神经网络预测;2 大概思想:在词袋里寻找和输出相关关键词语,然后统计词语数量,哪个输出类别的敏感词语统计数量高,结果就是哪个输出类别

数据预处理:可以归一化处理,网络计算会更快

3.1 简单文本分类器

背景与数据:获取网购商城顾客评论信息,从评论信息获取句子情绪

神经网络组成:用前馈神经网络,有3层,分别是输入,中间,输出层,输出层数量有两个,分别代表情绪是正面还是负面

3.2 程序实现

3.2.1 数据处理

主要步骤:处理标点,句子分词,建立词表

标点处理:正则,涉及的标点全部删除

句子分词:使用jieba模块对句子分词

建立词表:python字典建立词表


def deal_punc(sentence):
    sentence = re.sub('[\s+\.\!\/_,$%^*(+\"\'“”《》?]+|[+——!,。?、~@#¥%……&*():)]+', '', sentence)
    return sentence


def prepare_data(good_file, bad_file, is_filter=True):
    print('prepare data begin')
    all_words = []
    pos_sen, neg_sen = [], []
    for path, sen in zip((good_file, bad_file), (pos_sen, neg_sen)):
        with open(path, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                if is_filter:
                    line = deal_punc(line)
                
                words = jieba.lcut(line)
                if len(words) > 0:
                    all_words += words
                    sen.append(words)
            print(f'{path} include {index} rows, all words:{len(all_words)}')
    print(f'pos_sen len:{len(pos_sen)}, neg_sen len:{len(neg_sen)}')
            
    diction = {}
    cnt = Counter(all_words)
    for word, freq in cnt.items():
        diction[word] = [len(diction), freq]
    print(f'diction len:{len(diction)}')
    return (pos_sen, neg_sen, diction)

def word2index(word, diction):
    if word in diction:
        value = diction[word][0]
    else:
        value = -1
    return (value)

def index2word(index, diction):
    for w, v in diction.items():
        if v[0] == index:
            return (w)
    return (None)

3.2.2 文本数据向量化

    pos_sen, neg_sen, diction = prepare_data_sample(good_file, bad_file)
    st = sorted([(v[1], w) for w, v in diction.items()])
    datasets, labels, sentences = [], [], []
    '''
    for sens, tag in zip((pos_sen, neg_sen), (0, 1)):
        for sen in sens:
            new_sen = []
            for l in sen:
                if l in diction:
                    new_sen.append(word2index_sample(1, diction))
            datasets.append(sentence2vec_sample(new_sen, diction))
            labels.append(tag)
            sentences.append(sen)
    '''
    for sentence in pos_sen:
        new_sentence = []
        for l in sentence:
            if l in diction:
                new_sentence.append(word2index(l, diction))
        datasets.append(sentence2vec_sample(new_sentence, diction))
        labels.append(0) #正标签为0
        sentences.append(sentence)

    # 处理负向评论
    for sentence in neg_sen:
        new_sentence = []
        for l in sentence:
            if l in diction:
                new_sentence.append(word2index(l, diction))
        datasets.append(sentence2vec_sample(new_sentence, diction))
        labels.append(1) #负标签为1
        sentences.append(sentence)
    indices = np.random.permutation(len(datasets))
    datasets = [datasets[i] for i in indices]
    labels = [labels[i] for i in indices]
    sentences = [sentences[i] for i in indices]

3.3.3 划分数据集

训练集,校验集,测试集,比例一般为10:1:1

    # split train, test, verify datasets
    test_size = int(len(datasets) // 10)
    train_data = datasets[2 * test_size :]
    train_label = labels[2 * test_size : ]
    
    valid_data = datasets[: test_size]
    valid_label = labels[: test_size]
    
    test_data = datasets[test_size : 2 * test_size]
    test_label = labels[test_size : 2 * test_size]

3.3.4 建立神经网络

结构:输入层(约7000个样本),一个中间层(10个隐单元),输出层(2个单元)

注意,此处用的是relu,不是sigmoid。原因:虽然x>0时没处理数据,但x<0时为0让relu有了非线性的能力,且运算比sigmoid简单,速度更快,方便反向误差传导


def plot(records):
    print('plot begin')
    a = [i[0] for i in records]
    b = [i[1] for i in records]
    c = [i[2] for i in records]
    pyplot.plot(a, label='train loss')
    pyplot.plot(b, label='verify loss')
    pyplot.plot(c, label='valid accuracy')
    pyplot.xlabel('step')
    pyplot.ylabel('losses & accuracy')
    pyplot.legend()
    pyplot.show()


def main():
    model = torch.nn.Sequential(
        torch.nn.Linear(len(diction), 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 2),
        torch.nn.LogSoftmax(dim=1),
    )
    cost = torch.nn.NLLLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    records = []
    losses = []
    for epoch in range(3):
        for i, data in enumerate(zip(train_data, train_label)):
            x, y = data
            x = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)
            y = torch.tensor(np.array([y]), dtype=torch.long)
            optimizer.zero_grad()
            predict = model(x)
            loss = cost(predict, y)
            losses.append(loss.data.numpy())
            loss.backward()
            optimizer.step()

        val_losses = []
        rights = []
        for j, val in enumerate(zip(valid_data, valid_label)):
            x, y = val
            x = torch.tensor(x, requires_grad=True, dtype=torch.float).view(1, -1)
            y = torch.tensor(np.array([y]), dtype = torch.long)
            predict = model(x)
            right = rightness_sample(predict, y)
            rights.append(right)
            loss = cost(predict, y)
            val_losses.append(loss.data.numpy())
        right_ratio = 1.0 * np.sum([i[0] for i in rights]) / np.sum([i[1] for i in rights])
        print(f'No.{epoch}, train loss:{np.mean(losses):.4f}, verify loss:{np.mean(val_losses):.4f}, verify accuracy:{right_ratio}')
        records.append([np.mean(losses), np.mean(val_losses), right_ratio])
    # plot
    plot(records)

3.3 运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/321510.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络期末复习(基础概念+三套卷子练习)

第一章&#xff1a;计算机网络概念 计算机网络的概念 计算机网络的定义 计算机网络是指将地理位置不同的 具有独立功能的多台计算机 及其外部设备&#xff0c;通过 通信线路链接 起来&#xff0c;在网络操作系统&#xff0c;网络管理软件及网络通信协议的管理和协调下&#…

【生存技能】git操作

先下载git https://git-scm.com/downloads 我这里是win64&#xff0c;下载了相应的直接安装版本 64-bit Git for Windows Setup 打开git bash 设置用户名和邮箱 查看设置的配置信息 获取本地仓库 在git bash或powershell执行git init&#xff0c;初始化当前目录成为git仓库…

精品量化公式——“区域突破”,应对当下行情较好的主图看盘策略

不多说&#xff0c;直接上效果如图&#xff1a; ► 日线表现 代码评估 技术指标代码评估&#xff1a; VAR1, VAR2, VAR3&#xff1a;这些变量是通过指数移动平均&#xff08;EMA&#xff09;计算得出的。EMA是一种常用的技术分析工具&#xff0c;用于平滑价格数据并减少市场“…

Ubuntu共享文件到win

Ubuntu共享文件到win 1、安装samba sudo apt-get install samba samba-common2、创建一个共享文件夹&#xff0c;并设置777权限 mkdir /home/qyh/share sudo chmod 777 /home/qyh/share我的用户名&#xff1a;qyh。 3、添加用户及密码 sudo smbpasswd -a qyh4、修改配置文…

操作系统复习篇

目录 一、引论 1.1、操作系统的目标 方便性&#xff1a; 有效性&#xff1a; 可扩充性&#xff1a; 开放性&#xff1a; 1.2、操作系统的作用 用户与计算机硬件系统之间的接口&#xff1a; 计算机系统资源的管理者&#xff1a; 实现对计算机资源的抽象&#xff1a; 1…

SqlAlchemy使用教程(二) 入门示例及编程步骤

SqlAlchemy使用教程(一) 原理与环境搭建SqlAlchemy使用教程(三) CoreAPI访问与操作数据库详解 二、入门示例与基本编程步骤 在第一章中提到&#xff0c;Sqlalchemy提供了两套方法来访问数据库&#xff0c;由于Sqlalchemy 官方文档结构有些乱&#xff0c;对于ORM的使用步骤的描…

使用scipy处理图片——旋转任意角度

大纲 载入图片左旋转30度&#xff0c;且重新调整图片大小右旋转30度&#xff0c;且重新调整图片大小左旋转135度&#xff0c;保持图片大小不变右旋转135度&#xff0c;保持图片大小不变 在《使用numpy处理图片——90度旋转》中&#xff0c;我们使用numpy提供的方法&#xff0c;…

庆祝一年的成长

本文字数&#xff1a;2288&#xff1b;估计阅读时间&#xff1a;6 分钟 作者&#xff1a;ClickHouse Team 审校&#xff1a;庄晓东&#xff08;魏庄&#xff09; 本文在公众号【ClickHouseInc】首发 随着今年即将结束&#xff0c;我们想要向您表达衷心的感谢&#xff0c;感谢您…

实现STM32烧写程序-(3) Hex文件结构

简介 要对STM32进行更新动作, 就需要对程序文件进行解析, 大部分编译的生成程序文件是Hex或者Bin, 先来看看Hex的结构吧。 资料 Hex文件 简介 Hex文件格式最早由Intel公司于1973年创建。它最初是为了在Intel 8080微处理器上存储和传输二进制数据而设计的。随后&#xff0c;Hex…

XCTF:Hidden-Message[WriteUP]

使用Wireshark打开文件 分析能分析的流&#xff0c;这里直接选择UDP流 分别有两段流&#xff0c;内容都是关于物理的 和flag没啥关系&#xff0c;只能从别的方面下手 分析&#xff1a;整个数据包&#xff0c;全部由UDP协议组成 其中发送IP和接收IP固定不变&#xff0c;数据长…

Leetcode3002. 移除后集合的最多元素数

Every day a Leetcode 题目来源&#xff1a;3002. 移除后集合的最多元素数 解法1&#xff1a;贪心 可以将数组去重后分为三个部分&#xff1a;nums1 独有的&#xff0c;nums2 独有的&#xff0c;nums1 与 nums2 共有的。 求集合 S 时&#xff1a; 先选择两个数组独有的。…

【软件测试】学习笔记-性能测试的基本方法与应用领域

这篇文章探讨并发用户数、响应时间和系统吞吐量这三个指标之间的关系和约束&#xff0c;性能测试七种常用方法&#xff0c;以及四大应用领域。 由于性能测试是一个很宽泛的话题&#xff0c;所以不同的人对性能测试的看法也不完全一样&#xff0c;同样一种方法可能也会有不同的…

第十五届蓝桥杯单片机组备赛——独立键盘矩阵键盘

文章目录 一、按键原理二、独立键盘&矩阵键盘2.1 独立按键2.2 矩阵键盘 一、按键原理 原理很简单&#xff0c;当我们没有按下SW2时&#xff0c;由于上拉电阻得作用&#xff0c;使得输入引脚得信号为高电平&#xff0c;当按下按键后&#xff0c;引脚直接接地&#xff0c;输入…

MySQL面试题 | 08.精选MySQL面试题

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

List集合遍历过程中修改元素(有坑)

说来惭愧&#xff0c;学 java 2年了&#xff0c;对 “Java是值传递” 这句话还没有理解它的精髓&#xff0c;以至于编程的时候出现了一些错误&#xff0c;这里记录一下。 一.问题再现 1.1将List集合中的每个字符串更改为其他值 1.2将List集合中的对象更改为其他对象 二.问题分…

【DolphinScheduler】datax读取hive分区表时,空分区、分区无数据任务报错问题解决

问题背景&#xff1a; 最近在使用海豚调度DolphinScheduler的Datax组件时&#xff0c;遇到这么一个问题&#xff1a;之前给客户使用海豚做的离线数仓的分层搭建&#xff0c;一直都运行好好的&#xff0c;过了个元旦&#xff0c;这几天突然在数仓做任务时报错&#xff0c;具体报…

【Rust日报】2024-01-12 将 Rust 引入 Git 项目

将 Rust 引入 Git 项目 去年年底的假期里&#xff0c;Taylor Blau 花了一些时间思考如何将 Rust 引入 Git 项目。 将 Rust 引入 Linux 内核的重要工作正在进行中。在他们既定的目标中&#xff0c;他认为有一些可能与 Git 项目相关&#xff1a; 由于语言的安全保证&#xff0c;内…

Jenkins+nexus

jiekins安装完成 1、安装java环境 [rootnexus ~]# tar -xf jdk-8u211-linux-x64.tar.gz -C /usr/local [rootnexus ~]# vim /etc/profile.d/java.sh JAVA_HOME/usr/local/jdk1.8.0_211 PATH$PATH:$JAVA_HOME/bin [rootnexus ~]# source /etc/profile.d/java.sh 必须要选择与n…

UE5 伤害数字跳出

学习视频 大体思路&#xff1a; 1.创建一个控件蓝图。 播放动画&#xff0c;K透明度&#xff0c;文本位置。 2.创建一个控件组件。 类默认值中设置。 3.在位置处创建组件实例 自定义事件略。

汇编语言与接口技术——AD转换及PWM控制

一、 实验要求 实验目的&#xff1a; 掌握SPI总线的使用方式掌握xpt2046 AD转换芯片的工作原理掌握SPI总线方式实现基于xpt2046的AD转换掌握PWM控制功率的方式 实验内容&#xff1a; 学习xpt2046 AD转换芯片的工作原理&#xff0c;利用SPI总线实现基于该芯片的AD转换&#x…