RNN:文本生成

文章目录

    • 一、完整代码
    • 二、过程实现
      • 2.1 导包
      • 2.2 数据准备
      • 2.3 字符分词
      • 2.4 构建数据集
      • 2.5 定义模型
      • 2.6 模型训练
      • 2.7 模型推理
    • 三、整体总结

采用RNN和unicode分词进行文本生成

一、完整代码

作者在文章开头地址中使用C++实现了这一过程,为了便于理解,这里我们使用python代码进行实现

# 完整代码在这里
import tensorflow as tf
import keras_nlp
import numpy as np

tokenizer = keras_nlp.tokenizers.UnicodeCodepointTokenizer(vocabulary_size=400)

# tokens - ids
ids = tokenizer(['Why are you so funny?', 'how can i get you'])

# ids - tokens
tokenizer.detokenize(ids)

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

# 准备数据
text = open('./shakespeare.txt', 'rb').read().decode(encoding='utf-8')
dataset = tf.data.Dataset.from_tensor_slices(tokenizer(text))
dataset = dataset.batch(64, drop_remainder=True)
dataset = dataset.map(split_input_target).batch(64)


input, ouput = dataset.take(1).get_single_element()

# 定义模型

d_model = 512
rnn_units = 1025

class CustomModel(tf.keras.Model):
    def __init__(self, vocabulary_size, d_model, rnn_units):
        super().__init__(self)
        self.embedding = tf.keras.layers.Embedding(vocabulary_size, d_model)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

model = CustomModel(tokenizer.vocabulary_size(), d_model, rnn_units)

# 查看模型结构
model(input)
model.summary()

# 模型配置
model.compile(
    loss = tf.losses.SparseCategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy']
)

# 模型训练
model.fit(dataset, epochs=3)

# 模型推理
class InferenceModel(tf.keras.Model):
    def __init__(self, model, tokenizer):
        super().__init__(self)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, inputs, length, return_states=False):
        inputs = inputs = tf.constant(inputs)[tf.newaxis]
        
        states = None
        input_ids = self.tokenizer(inputs).to_tensor()
        outputs = []
        for i in range(length):
            predicted_logits, states = model(inputs=input_ids, states=states, return_state=True)
            input_ids = tf.argmax(predicted_logits, axis=-1)
            outputs.append(input_ids[0][-1].numpy())

        outputs = self.tokenizer.detokenize(lst).numpy().decode('utf-8')
        if return_states:
            return outputs, states
        else:
            return outputs

infere = InferenceModel(model, tokenizer)


# 开始推理
start_chars = 'hello'
outputs = infere.generate(start_chars, 1000)
print(start_chars + outputs)

二、过程实现

2.1 导包

先导包tensorflow, keras_nlp, numpy

import tensorflow as tf
import keras_nlp
import numpy as np

2.2 数据准备

数据来自莎士比亚的作品 storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt;我们将其下载下来存储为shakespeare.txt

2.3 字符分词

这里我们使用unicode分词:将所有字符都作为一个词来进行分词

tokenizer = keras_nlp.tokenizers.UnicodeCodepointTokenizer(vocabulary_size=400)

# tokens - ids
ids = tokenizer(['Why are you so funny?', 'how can i get you'])

# ids - tokens
tokenizer.detokenize(ids)

2.4 构建数据集

利用tokenizertext数据构建数据集

def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text

text = open('./shakespeare.txt', 'rb').read().decode(encoding='utf-8')
dataset = tf.data.Dataset.from_tensor_slices(tokenizer(text))
dataset = dataset.batch(64, drop_remainder=True)
dataset = dataset.map(split_input_target).batch(64)


input, ouput = dataset.take(1).get_single_element()

2.5 定义模型

d_model = 512
rnn_units = 1025

class CustomModel(tf.keras.Model):
    def __init__(self, vocabulary_size, d_model, rnn_units):
        super().__init__(self)
        self.embedding = tf.keras.layers.Embedding(vocabulary_size, d_model)
        self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(vocabulary_size, activation='softmax')

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)
        if return_state:
            return x, states
        else:
            return x

model = CustomModel(tokenizer.vocabulary_size(), d_model, rnn_units)

# 查看模型结构
model(input)
model.summary()

2.6 模型训练

model.compile(
    loss = tf.losses.SparseCategoricalCrossentropy(),
    optimizer='adam',
    metrics=['accuracy']
)

model.fit(dataset, epochs=3)

2.7 模型推理

定义一个InferenceModel进行模型推理配置;

class InferenceModel(tf.keras.Model):
    def __init__(self, model, tokenizer):
        super().__init__(self)
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, inputs, length, return_states=False):
        inputs = inputs = tf.constant(inputs)[tf.newaxis]
        
        states = None
        input_ids = self.tokenizer(inputs).to_tensor()
        outputs = []
        for i in range(length):
            predicted_logits, states = model(inputs=input_ids, states=states, return_state=True)
            input_ids = tf.argmax(predicted_logits, axis=-1)
            outputs.append(input_ids[0][-1].numpy())

        outputs = self.tokenizer.detokenize(lst).numpy().decode('utf-8')
        if return_states:
            return outputs, states
        else:
            return outputs

infere = InferenceModel(model, tokenizer)


start_chars = 'hello'
outputs = infere.generate(start_chars, 1000)
print(start_chars + outputs)

生成结果如下所示,感觉很差:

hellonofur us:
medous, teserwomador.
walled o y.
as
t aderemowate tinievearetyedust. manonels,
w?
workeneastily.
watrenerdores aner'shra
palathermalod, te a y, s adousced an
ptit: mamerethus:
bas as t: uaruriryedinesm's lesoureris lares palit al ancoup, maly thitts?
b veatrt
watyeleditenchitr sts, on fotearen, medan ur
tiblainou-lele priniseryo, ofonet manad plenerulyo
thilyr't th
palezedorine.
ti dous slas, sed, ang atad t,
wanti shew.
e
upede wadraredorenksenche:
wedemen stamesly ateara tiafin t t pes:
t: tus mo at
io my.
ane hbrelely berenerusedus' m tr;
p outellilid ng
ait tevadwantstry.
arafincara, es fody
'es pra aluserelyonine
pales corseryea aburures
angab:
sunelyothe: s al, chtaburoly o oonis s tioute tt,
pro.
tedeslenali: s 't ing h
sh, age de, anet: hathes: s es'tht,
as:
wedly at s serinechamai:
mored t.
t monatht t athoumonches le.
chededondirineared
t

er
p y
letinalys
ani
aconen,
t rs:
t;et, tes-
luste aly,
thonort aly one telus, s mpsantenam ranthinarrame! a
pul; bon
s fofuly

三、整体总结

RNN结合unicode分词能进行文本生成但是效果一言难尽!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/206052.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用vscode的remotessh插件远程连接的时候被要求重复输入密码

问题描述: 需要远程连接服务器,使用ssh,我用到的是vscode里面的remotessh插件。配置好config以后 HostHostNameUserPortIdentifyFile进入到了vscode的密码登录界面,但是一直被要求循环输入密码,很奇怪,去…

防爆执法记录仪、防爆智能安全帽助力海上钻井平台远程可视化监管平台建设

推动远程安全管理,海上钻井"视"界拓新—防爆执法记录仪与防爆智能安全帽的创新应用 在海上钻井作业领域,安全生产一直是萦绕在每一个业者心头的重大课题。由于环境的恶劣及作业的特殊性,一旦发生安全事故,其后果往往极…

CentOS 7安装Java 8

前言 这是我在这个网站整理的笔记,有错误的地方请指出,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 要在CentOS 7上安装Java 8,请按照以下步骤操作: 打开终端并以root身份登录。 更新系统软件包: …

Mybatis 的简单运用介绍

Mybatis 用于操作数据库 操作数据库肯定需要: 1.SQL语句 2.数据库对象和 java 对象的映射 接下来我们看看怎么使用 Mybatis 我们先搞一些数据库内容 然后将其这些内容和Java对象进行映射 再创建一个类实现 select * from 再写一个类证明上述代码是否可以实现 别忘了在appli…

中信建投在金融电于化期刊发布 DataOps 实践

文 ‖ 中信建投证券股份有限公司 马丽霞 高宇航 李可 许哲 李海伟 近年来,数据的分析和应用对各行各工业的业务模式和竞争形态进行重塑,而积极应对挑战和顺应时代变化是各个市场参与者的必选项。作为资本市场数字化转型的领航者,中信建投证券…

python动态加载内容抓取问题的解决实例

问题背景 在网页抓取过程中,动态加载的内容通常无法通过传统的爬虫工具直接获取,这给爬虫程序的编写带来了一定的技术挑战。腾讯新闻(https://news.qq.com/)作为一个典型的动态网页,展现了这一挑战。 问题分析 动态…

指针(2)

函数指针数组 函数指针数组是一个用来存放函数指针(地址)的数组。 如上图,是将两个函数指针存入数组中。如何写函数指针数组名呢?我们可以先写出函数指针类型int (*)(int,int)然后在(*)里面加上数组名[]即可。 指向函数指针数组…

面试题:什么是负载均衡?常见的负载均衡策略有哪些?

文章目录 一、负载均衡二、负载均衡模型分类三、CDN负载均衡四、LVS负载均衡4.1 LVS 支持的三种模式4.1.1 DR 模式4.1.2 TUN 模式4.1.3 NAT 模式 4.2 LVS 基于 Netfilter 的框架实现 五、负载均衡策略是什么六、常用负载均衡策略图解6.1 轮询6.2 加权轮询6.3 最少连接数6.4 最快…

Ubuntu使用Nginx部署前端项目——记录

安装nginx 依次执行以下两条命令进行安装: sudo apt-get update sudo apt-get install nginx通过查看版本号查看是否安装成功: nginx -v补充卸载操作: sudo apt-get remove nginx nginx-common sudo apt-get purge nginx nginx-common su…

解决:ModuleNotFoundError: No module named ‘xlrd‘

解决:ModuleNotFoundError: No module named ‘xlrd’ 文章目录 解决:ModuleNotFoundError: No module named xlrd背景报错问题报错翻译报错位置代码报错原因解决方法今天的分享就到此结束了 背景 在使用之前的代码时,报错: pin_r…

C 中的结构 - 存储、指针、函数和自引用结构

0. 结构体的内存分配 当声明某种类型的结构变量时,结构成员被分配连续(相邻)的内存位置。 struct student{char name[20];int roll;char gender;int marks[5];} stu1; 此处,内存将分配给name[20]、roll、gender和marks[5]。st1这…

11-30 JavaWeb

修改与删除操作 防止空指针异常 localhost:8080 -> 分页查询 修改流程:(先查后改(两个servlet)) 修改: 传用户id(用户id怎么得到 -> 循环一次得到一个user 对象 user对象里用user.getId()得到用户id) UpdateUserQueryServlet.java (…

「Verilog学习笔记」状态机-重叠序列检测

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 读入数据移位寄存,寄存后的数据与序列数做对比,相等则flag为1,不等则为0 timescale 1ns/1nsmodule sequence_test2(input wire clk ,in…

计网Lesson5 - MAC 地址与 ARP

文章目录 M A C MAC MAC 地址1. M A C MAC MAC 地址的格式 2. M A C MAC MAC 地址的获取3. A R P ARP ARP 协议4. A R P ARP ARP 缓存5. R A R P RARP RARP M A C MAC MAC 地址 1. M A C MAC MAC 地址的格式 每个网卡都有一个 6 6 6 字节的 M A C MAC MAC 地址 M A C…

最大公约数的C语言实现xdoj31

时间限制: 1 S 内存限制: 1000 Kb 问题描述: 最大公约数(GCD)指某几个整数共有因子中最大的一个,最大公约数具有如下性质, gcd(a,0)a gcd(a,1)1 因此当两个数中有一个为0时,gcd是不为0的那个整数&#xff…

ios 逆向分分析,某业帮逆向算法(一)

用到工具: 爱思助手CrackerXL(砸壳软件)越狱手机ida反汇编软件分析login 的sign 签名算法中自己写算法 已知我们32位,我们不妨猜测是md5 ,那我们试图使用CC_MD5 ,这个是ios 中的标准库, 我们使用frida-trace 注入hook一下,看看有没有 经过 是经过了这个函数,密码也是…

计算机服务器中了_locked勒索病毒如何处理,_locked勒索病毒解密数据恢复

网络技术的不断发展,给企业的生产生活提供了极大便利,越来越多的企业走向数字化办公时代,但网络的发展也为网络安全埋下隐患,网络安全威胁不断增加。近期,云天数据恢复中心陆续接到很多企业的求助,企业的计…

FO-like Transformation Oracle Cloning

参考文献: [RS91] Rackoff C, Simon D R. Non-interactive zero-knowledge proof of knowledge and chosen ciphertext attack[C]//Annual international cryptology conference. Berlin, Heidelberg: Springer Berlin Heidelberg, 1991: 433-444.[BR93] Bellare M…

网狐类源码游戏配置数据库数据(一键配置网狐数据库)

网狐类源码游戏配置数据库数据(一键配置网狐数据库) 一般拿到网狐的源码或组件,需要先附加或配置数据库,以下为全部需要更改数据的地方,这里以荣耀系列版本数据库为例: 1. 数据库设置 [RYPlatformDB].…

appium :输入框控件为android.view.View 时输入内容(如:验证码、密码输入框)

问题背景 输入密码的组件信息为&#xff1a;<android.view.View resource-id“com.qq.ac.android:id/pwd_input”> 由于输入框控件是android.view.View&#xff0c;不是android.widget.EditText&#xff0c;所以只能点击&#xff0c;而启动appium后&#xff0c;会将输入…