TensorFlow2实战-系列教程11:RNN文本分类3

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

6、构建训练数据

  • 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
  • tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
  • tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):
    with open(f_path,encoding='utf-8') as f:
        print('Reading', f_path)
        for line in f:
            line = line.rstrip()
            label, text = line.split('\t')
            text = text.split(' ')
            x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的ID
            if len(x) >= params['max_len']:#截断操作
                x = x[:params['max_len']]
            else:
                x += [0] * (params['max_len'] - len(x))#补齐操作
            y = int(label)
            yield x, y
  1. 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
  2. 打开文件
  3. 打印文件路径
  4. 遍历每行数据
  5. 获取标签和文本
  6. 文本按照空格分离出单词
  7. 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
  8. 如果当前句子大于预设的最大句子长度
  9. 进行截断操作
  10. 如果小于
  11. 补充0
  12. 标签从str转换为int类型
  13. yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中

也就是说yield 会从上一次取得地方再接着去取数据,而return却不会

def dataset(is_training, params):
    _shapes = ([params['max_len']], ())
    _types = (tf.int32, tf.int32)
  
    if is_training:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['train_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.shuffle(params['num_samples'])
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
    else:
        ds = tf.data.Dataset.from_generator(
            lambda: data_generator(params['test_path'], params),
            output_shapes = _shapes,
            output_types = _types,)
        ds = ds.batch(params['batch_size'])
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  
    return ds
  1. 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
  2. 当前shape值
  3. 1
  4. 是否在训练,如果是:
  5. 构建一个Dataset
  6. 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
  7. 输出的shape值
  8. 输出的类型
  9. 指定shuffle
  10. 指定 batch_size
  11. 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
  12. 如果不是在训练,则:
  13. 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
  14. 最后把做好的Datasets返回回去

7、自定义网络模型

在这里插入图片描述
一条文本变成一组向量/矩阵的基本流程:

  1. 拿到一个英文句子
  2. 通过查语料表将句子变成一组索引
  3. 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了

在这里插入图片描述
在这里插入图片描述
BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。

class Model(tf.keras.Model):
    def __init__(self, params):
        super().__init__()
        self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)
        self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])
        self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])

        self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))
        self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))
        self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])
        self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)
        self.out_linear = tf.keras.layers.Dense(2)  
    def call(self, inputs, training=False):
        if inputs.dtype != tf.int32:
            inputs = tf.cast(inputs, tf.int32)
        batch_sz = tf.shape(inputs)[0]
        rnn_units = 2*params['rnn_units']
        x = tf.nn.embedding_lookup(self.embedding, inputs)        
        x = self.drop1(x, training=training)
        x = self.rnn1(x)
        x = self.drop2(x, training=training)
        x = self.rnn2(x)
        x = self.drop3(x, training=training)
        x = self.rnn3(x)
        x = self.drop_fc(x, training=training)
        x = self.fc(x)
        x = self.out_linear(x)
        return x
  1. 自定义一个模型,继承tf.keras.Model模块
  2. 初始化函数
  3. 初始化
  4. 词嵌入,把之前保存好的词嵌入文件向量读进来
  5. 定义一层dropout1
  6. 定义一层dropout2
  7. 定义一层dropout3
  8. 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
  9. 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
  10. 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
  11. 定义全连接层的dropout
  12. 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
  13. 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
  14. 定义前向传播函数,传进来一个batch的数据和是否是在训练
  15. 如果输入数据不是tf.int32类型
  16. 转换成tf.int32类型
  17. 取出batch_size
  18. 设置LSTM神经元个数,双向乘以2
  19. 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
  20. 数据通过第1个 Dropout 层
  21. 数据通过第1个rnn
  22. 数据通过第2个 Dropout 层
  23. 数据通过第2个rnn
  24. 数据通过第3个 Dropout 层
  25. 数据通过第3个rnn
  26. 经过全连接层对应的Dropout
  27. 数据通过一个全连接层
  28. 最后,数据通过一个输出层
  29. 返回最终的模型输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/359651.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

苍穹外卖项目可以写的简历和如何优化简历

文章目录 重点写中规写添加自己个性的项目面试会问道的问题 我是一名双非大二计算机本科生,希望我的分享对你有帮助,点赞关注不迷路。 简历编写一直是很多人求职人的心病,我自己上学期有一门课程是去校内企业面试,当时我就感受出…

【lesson26】学习MySQL事务前的基础知识

文章目录 CURD不加控制,会有什么问题?CURD满足什么属性,能解决上述问题?什么是事务?为什么会出现事务事务的版本支持 CURD不加控制,会有什么问题? CURD满足什么属性,能解决上述问题&…

使用Python的Turtle模块来绘制爱心图案

import turtle as t# 设置画布大小和颜色 t.setup(800, 600) t.bgcolor(white)# 设置画笔颜色和粗细 t.pensize(2) t.color(red)# 定义爱心函数 def heart():t.begin_fill()t.left(140)t.forward(224)for i in range(200):t.right(1)t.forward(2)t.left(120)for i in range(200…

Springboot+Redis

首先前提我们要在自己的本机电脑或者服务器上安装一个redis的服务器 Redis配置 添加依赖: <!-- SpringBoot Boot Redis --> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artif…

sklearn 计算 tfidf 得到每个词分数

from sklearn.feature_extraction.text import TfidfVectorizer# 语料库 可以换为其它同样形式的单词 corpus [list(range(-5, 5)),list(range(-6,4)),list(range(12)),list(range(13))]# corpus [ # [Two, wrongs, don\t, make, a, right, .], # [The, pen, is, might…

【小白学unity记录】使用unity播放声音

1. 示例 unity中播放声音涉及到两个组件。AudioSource和AudioClip。AudioSource可以理解为播放器&#xff0c;AudioClip可以理解为音频片段文件。AudioSource可以通过.clip属性切换音频片段。 using UnityEngine;public class PlayerController : MonoBehaviour {private int…

油分离器的介绍

压缩机的排气中带有冷冻机油&#xff0c;这些冷冻机油如果随制冷剂蒸汽进入冷凝器、蒸发器后将 在传热表面形成油膜&#xff0c;从而影响换热效果。因此通常在压缩机与冷凝器之间装设油分离器&#xff0c;用 来分离制冷剂蒸汽中挟带的冷冻机油。在氟利昂制冷系统中&#xff0c;…

搜索引擎评价指标及指标间的关系

目录 二分类模型的评价指标准确率(Accuracy,ACC)精确率(Precision,P)——预测为正的样本召回率(Recall,R)——正样本注意事项 P和R的关系——成反比F值F1值F值和F1值的关系 ROC&#xff08;Receiver Operating Characteristic&#xff09;——衡量分类器性能的工具AUC&#xff…

线性代数------矩阵的运算和逆矩阵

矩阵VS行列式 矩阵是一个数表&#xff0c;而行列式是一个具体的数&#xff1b; 矩阵是使用大写字母表示&#xff0c;行列式是使用类似绝对值的两个竖杠&#xff1b; 矩阵的行数可以不等于列数&#xff0c;但是行列式的行数等于列数&#xff1b; 1.矩阵的数乘就是矩阵的每个…

c学习:sqlite3数据库操作

目录 获取sqlite3源码 c调用步骤 常用接口函数说明 例子 打开数据库&#xff0c;新建表&#xff0c;插入数据&#xff0c;查询数据&#xff0c;关闭数据库 查询数据需要在回调函数中获取 获取sqlite3源码 先下载c的sqlite3源码&#xff0c;https://www.sqlite.org/inde…

【蓝桥杯冲冲冲】进阶搜索 Anya and Cubes

蓝桥杯备赛 | 洛谷做题打卡day22 文章目录 蓝桥杯备赛 | 洛谷做题打卡day22Anya and Cubes题面翻译输入格式输出题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 样例 #2样例输入 #2样例输出 #2 样例 #3样例输入 #3样例输出 #3 提示题解代码我的一些话 Anya and Cubes …

HBase表结构

HBase是非关系型数据库&#xff0c;是高可靠性、高性能、面向列、可伸缩、实时读写的分布式数据库。 HBase使用场景 大规模数据存储&#xff1a;如日志记录、数据库备份等。实时数据访问&#xff1a;如实时搜索、实时分析等。高性能读写&#xff1a;如高并发、低延迟的读写操…

鸿蒙南向开发——GN快速入门指南

运行GN(Generate Ninja) 运行gn&#xff0c;你只需从命令行运行gn&#xff0c;对于大型项目&#xff0c;GN是与源码一起的。 对于Chromium和基于Chromium的项目&#xff0c;有一个在depot_tools中的脚本&#xff0c;它需要加入到你的PATH环境变量中。该脚本将在包含当前目录的…

系统分析师-22年-下午题目

系统分析师-22年-下午题目 更多软考知识请访问 https://ruankao.blog.csdn.net/ 试题一必答&#xff0c;二、三、四、五题中任选其中两题作答 试题一 (25分) 说明 某软件公司拟开发一套博客系统&#xff0c;要求能够向用户提供一个便捷发布自已心得&#xff0c;及时有效的…

什么是事务?

事务 是一组操作的集合&#xff0c;它是一个不可分割的工作单位。事务会把所有的操作作为一个整体&#xff0c;一起向数据库提交或者是撤销操作请求。所以这组操作要么同时成功&#xff0c;要么同时失败。 1. 事务管理 怎么样来控制这组操作&#xff0c;让这组操作同时成功或…

故障诊断 | 一文解决,CNN卷积神经网络故障诊断(Matlab)

文章目录 效果一览文章概述专栏介绍源码设计参考资料效果一览 文章概述 故障诊断 | 一文解决,CNN卷积神经网络故障诊断(Matlab) 专栏介绍 订阅【故障诊断】专栏,不定期更新机器学习和深度学习在故障诊断中的应用;订阅

oracle错误:The Network Adapter could not establish the connection

执行请求的操作时遇到错误: IO 错误: The Network Adapter could not establish the connection (CONNECTION_IDU34sFBqOSayf4o4C6pwQ6A) 供应商代码 17002 原因&#xff1a; 错误代码 17002 表示 Oracle 数据库客户端遇到了网络适配器无法建立连接的问题 解决办法&#x…

ModelArts加速识别,助力新零售电商业务功能的实现

前言 如果说为客户提供最好的商品是产品眼中零售的本质&#xff0c;那么用户的思维是什么呢&#xff1f; 在用户眼中&#xff0c;极致的服务体验与优质的商品同等重要。 企业想要满足上面两项服务&#xff0c;关键在于提升效率&#xff0c;也就是需要有更高效率的零售&#…

Apache Flink文件上传漏洞(CVE-2020-17518)漏洞代码分析

漏洞复现参考如下文章 Apache Flink文件上传漏洞&#xff08;CVE-2020-17518&#xff09;漏洞复现分析_文件上传漏洞复现cve-CSDN博客 分析代码的话&#xff0c;首先找到漏洞修复的邮件 漏洞详情&#xff0c;可以看到漏洞概要&#xff0c;影响的版本&#xff0c;漏洞描述以及…

EasyExcel实现Excel文件导入导出

1 EasyExcel简介 EasyExcel是一个基于Java的简单、省内存的读写Excel的开源项目。在尽可能节约内存的情况下支持读写百M的Excel。 github地址: https://github.com/alibaba/easyexcel 官方文档: https://www.yuque.com/easyexcel/doc/easyexcel Excel解析流程图: EasyExcel读取…