再用RNN神经网络架构设计生成式语言模型

上一篇:《用谷歌经典ML方法方法来设计生成式人工智能语言模型》

序言:市场上所谓的开源大语言模型并不完全开源,通常只提供权重和少量工具,而架构、训练数据集、训练方法及代码等关键内容并未公开。因此,要真正掌握人工智能模型,仍需从基础出发。本篇文章将通过传统方法重新构建一个语言模型,以帮助大家理解语言模型的本质:它并不神秘,主要区别在于架构设计。目前主流架构是谷歌在论文《Attention Is All You Need》中提出的 Transformer,而本文选择采用传统的 RNN(LSTM)方法构建模型,其最大局限在于不支持高效并行化,因而难以扩展。

创建模型

现在让我们创建一个简单的模型,用来训练这些输入数据。这个模型由一个嵌入层、一个 LSTM 层和一个全连接层组成。对于嵌入层,你需要为每个单词生成一个向量,因此参数包括总单词数和嵌入的维度数。在这个例子中,单词数量不多,所以用八个维度就足够了。

你可以让 LSTM 成为双向的,步数可以是序列的长度,也就是我们的最大长度减去 1(因为我们从末尾去掉了一个 token 用作标签)。

最后,输出层将是一个全连接层,其参数为总单词数,并使用 softmax 激活。该层的每个神经元表示下一个单词匹配相应索引值单词的概率:

model = Sequential()

model.add(Embedding(total_words, 8))

model.add(Bidirectional(LSTM(max_sequence_len-1)))

model.add(Dense(total_words, activation='softmax'))

使用分类损失函数(例如分类交叉熵)和优化器(例如 Adam)编译模型。你还可以指定要捕获的指标:

python

Copy code

model.compile(loss='categorical_crossentropy',

optimizer='adam', metrics=['accuracy'])

这是一个非常简单的模型,数据量也不大,因此可以训练较长时间,比如 1,500 个 epoch:

history = model.fit(xs, ys, epochs=1500, verbose=1)

经过 1,500 个 epoch 后,你会发现模型已经达到了非常高的准确率(见图 8-6)。

图 8-6:训练准确率

当模型的准确率达到 95% 左右时,我们可以确定,如果输入一段它已经见过的文本,它预测下一个单词的准确率约为 95%。然而请注意,当生成文本时,它会不断遇到以前没见过的单词,因此尽管准确率看起来不错,网络最终还是会迅速生成一些无意义的文本。我们将在下一节中探讨这个问题。

文本生成

现在你已经训练了一个可以预测序列中下一个单词的网络,接下来的步骤是给它一段文本序列,让它预测下一个单词。我们来看看具体怎么做。

预测下一个单词

首先,你需要创建一个称为种子文本(seed text)的短语。这是网络生成内容的基础,它会通过预测下一个单词来完成这一点。

从网络已经见过的短语开始,例如:

seed_text = "in the town of athy"

接下来,用 texts_to_sequences 对其进行标记化。即使结果只有一个值,这个方法也会返回一个数组,因此你需要取这个数组的第一个元素:

python

Copy code

token_list = tokenizer.texts_to_sequences([seed_text])[0]

然后,你需要对该序列进行填充,使其形状与用于训练的数据相同:

token_list = pad_sequences([token_list],

maxlen=max_sequence_len-1, padding='pre')

现在,你可以通过对该序列调用 model.predict 来预测下一个单词。这将返回语料库中每个单词的概率,因此需要将结果传递给 np.argmax 来获取最有可能的单词:

predicted = np.argmax(model.predict(token_list), axis=-1)

print(predicted)

这应该会输出值 68。如果查看单词索引表,你会发现这对应的单词是 “one”:

'town': 66, 'athy': 67, 'one': 68, 'jeremy': 69, 'lanigan': 70,

你可以通过代码在单词索引中搜索这个值,并将其打印出来:

for word, index in tokenizer.word_index.items():

if index == predicted:

print(word)

break

因此,从文本 “in the town of athy” 开始,网络预测下一个单词是 “one”。如果你查看训练数据,会发现这是正确的,因为这首歌的开头是:

In the town of Athy one Jeremy Lanigan

Battered away til he hadn’t a pound

现在你已经确认模型可以正常工作,可以尝试一些不同的种子文本。例如,当我使用种子文本 “sweet jeremy saw dublin” 时,模型预测的下一个单词是 “then”。(这段文本的选择是因为其中的所有单词都在语料库中。至少在开始时,如果种子文本中的单词在语料库中,你应该可以期待更准确的预测结果。)

通过递归预测生成文本

在上一节中,你已经学会了如何用模型根据种子文本预测下一个单词。现在,为了让神经网络生成新的文本,你只需重复预测的过程,每次添加新的单词即可。

例如,之前我使用短语 “sweet jeremy saw dublin”,模型预测下一个单词是 “then”。你可以在种子文本后添加 “then”,形成 “sweet jeremy saw dublin then”,然后进行下一次预测。重复这个过程,就可以生成一段由 AI 创造的文本。

以下是上一节代码的更新版,这段代码会循环执行多次,循环次数由 next_words 参数决定:

seed_text = "sweet jeremy saw dublin"

next_words = 10

for _ in range(next_words):

token_list = tokenizer.texts_to_sequences([seed_text])[0]

token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')

predicted = model.predict_classes(token_list, verbose=0)

output_word = ""

for word, index in tokenizer.word_index.items():

if index == predicted:

output_word = word

break

seed_text += " " + output_word

print(seed_text)

运行这段代码后,可能会生成类似这样的字符串:

sweet jeremy saw dublin then got there as me me a call doing me

文本很快会变得无意义。为什么会这样?

1. 训练文本量太小:文本语料库的规模非常有限,因此模型几乎没有足够的上下文来生成合理的结果。

2. 预测依赖性:序列中下一个单词的预测高度依赖于之前的单词。如果前几个单词的匹配度较差,即使最佳的“下一个”预测单词的概率也会很低。当将这个单词添加到序列后,再预测下一个单词时,其匹配概率进一步降低,因此最终生成的单词序列看起来会像随机拼凑的结果。

例如,虽然短语 “sweet jeremy saw dublin” 中的每个单词都在语料库中,但这些单词从未以这种顺序出现过。在第一次预测时,模型选择了概率最高的单词 “then”,其概率高达 89%。当把 “then” 添加到种子文本中,形成 “sweet jeremy saw dublin then” 时,这个新短语也从未在训练数据中出现过。因此,模型将概率最高的单词 “got”(44%)作为预测结果。继续向句子中添加单词会使新短语与训练数据的匹配度越来越低,从而导致预测的准确率下降,生成的单词序列看起来越来越随机。

这就是为什么 AI 生成的内容随着时间推移会变得越来越不连贯的原因。例如,可以参考优秀的科幻短片《Sunspring》。这部短片完全由一个基于 LSTM 的网络生成,类似于你正在构建的这个模型,它的训练数据是科幻电影剧本。模型通过种子内容生成新的剧本。结果令人捧腹:尽管开头的内容尚且可以理解,但随着情节推进,生成的内容变得越来越不可理喻。

总结:本篇中,我们成功构建了一个语言模型,虽然未采用当前流行的 Transformer 架构,但这一实践让我们深刻理解了人工智能模型的多样性。即使是当下耗资巨大的 Transformer 模型,也可能很快被更新的架构取代。因此,我们的目标是紧跟技术前沿,掌握模型设计的核心理念。下一篇我们将诗歌如何通过扩展数据集来提高模型的准确度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/934479.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

httprunner实践样例

目录 1. 安装 HTTPRunner 2. 基本概念和目录结构 3. 编写一个 HTTPRunner 测试用例(YAML 示例) 4. 运行测试用例 5. 使用 Python 编写测试用例 6. 运行 Python 测试用例 7. 集成测试报告 8. 高级用法:集成环境变量、外部数据 9. 集成…

没有在 SCM 配置或者插件中的 Git 存储库配置错误

问题: jenkins 配置新项目后首次运行报错如下,同时git代码分支无法选择。 已返回默认值 没有在 SCM 配置或者插件中的 Git 存储库配置错误 选项"使用仓库"设置为:"http://xxxx.git 请检查配置 原因: 配置pipeline 脚本时指…

HBU深度学习实验15-循环神经网络(2)

LSTM的记忆能力实验 飞桨AI Studio星河社区-人工智能学习与实训社区 (baidu.com) 长短期记忆网络(Long Short-Term Memory Network,LSTM)是一种可以有效缓解长程依赖问题的循环神经网络.LSTM 的特点是引入了一个新的内部状态&am…

算法日记 46 day 图论(并查集)

题目:冗余连接 108. 冗余连接 (kamacoder.com) 题目描述 有一个图,它是一棵树,他是拥有 n 个节点(节点编号1到n)和 n - 1 条边的连通无环无向图(其实就是一个线形图),如图&#xff…

二叉树优选算法(一)

一、根据二叉树创建字符串 题目介绍: 给你二叉树的根节点 root ,请你采用前序遍历的方式,将二叉树转化为一个由括号和整数组成的字符串,返回构造出的字符串。 空节点使用一对空括号对 "()" 表示,转化后需…

RabbitMq死信队列延迟交换机

架构图 配置 package com.example.demo.config;import org.springframework.amqp.core.*; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration;Configuration public class DeadLetterConfig {public String …

JavaWeb学习--cookie和session,实现登录的记住我和验证码功能

目录 (一)Cookie概述 1.什么叫Cookie 2.Cookie规范 3.Cookie的覆盖 4.cookie的最大存活时间 ​​​​​​(Cookie的生命) (二) Cookie的API 1.创建Cookie:new 构造方法 2.保存到客户端浏…

开启第二阶段---蓝桥杯

一、12.10--数据类型的范围及转化 今天是刚开始,一天一道题 对于这道题我想要记录的是Java中的整数默认是 int 类型,如果数值超出了 int 的范围,就会发生溢出错误。为了避免这个问题,可以将数字表示为 long 类型,方法…

黑马头条学习笔记

Day01-环境搭建 项目概述 课程大纲 业务说明 技术栈 Spring-Cloud-Gateway : 微服务之前架设的网关服务,实现服务注册中的API请求路由,以及控制流速控制和熔断处理都是常用的架构手段,而这些功能Gateway天然支持 运用Spring Boot快速开发…

详解RabbitMQ在Ubuntu上的安装

​​​​​​​ 目录 Ubuntu 环境安装 安装Erlang 查看Erlang版本 退出命令 ​编辑安装RabbitMQ 确认安装结果 安装RabbitMQ管理界面 启动服务 查看服务状态 通过IP:port访问 添加管理员用户 给用户添加权限 再次访问 Ubuntu 环境安装 安装Erlang RabbitMq需要…

SpringBoot【二】yaml、properties两配置文件介绍及使用

一、前言 续上一篇咱们已经搭建好了一个springboot框架雏形。但是很多初学的小伙伴私信bug菌说,在开发项目中,为啥.yaml的配置文件也能配置,SpringBoot 是提供了两种2 种全局的配置文件嘛,这两种配置有何区别,能否给大…

Excel的文件导入遇到大文件时

Excel的文件导入向导如何把已导入数据排除 入起始行,选择从哪一行开始导入。 比如,前两行已经导入了,第二次导入的时候排除前两行,从第三行开始,就将导入起始行设置为3即可,且不勾选含标题行。 但遇到大文…

Windows平台Unity3D下RTMP播放器低延迟设计探讨

技术背景 好多开发者希望我们分享下大牛直播SDK是如何在Unity下实现低延迟的RTMP播放的,以下是一些降低 Unity 中 RTMP 播放器延迟的方法: 一、选择合适的播放插件或工具 评估和选用专业的流媒体插件 市场上有一些专门为 Unity 设计的流媒体插件&…

PaddleOCR模型ch_PP-OCRv3文本检测模型研究(一)骨干网络

从源码上看,PaddleOCR一共支持四个版本,分别是PP-OCR、PP-OCRv2、PP-OCRv3、PP-OCRv4。本文选择PaddleOCR的v3版本的骨干网络作为研究对象,力图探究网络模型的内部结构。 文章目录 研究起点卷归层压发层残差层骨干网代码实验小结 研究起点 参…

log4j漏洞复现--vulhub

声明:学习过程参考了同站的B1g0rang大佬的文章 Web网络安全-----Log4j高危漏洞原理及修复(B1g0rang) CVE-2021-44228 RCE漏洞 Log4j 即 log for java(java的日志) ,是Apache的一个开源项目,通过使用Log4j,我们可以控制日志信息输…

计算机网络ENSP课设--三层架构企业网络

本课程设计搭建一个小型互联网,并模拟Internet的典型Web服务过程。通过此次课程设计,可以进一步理解Internet的工作原理和协议过程,并提高综合知识的运用能力和分析能力。具体目标包括: (1)掌握网络拓扑的…

SQL 获取今天的当月开始结束范围:

使用 GETDATE() 结合 DATEADD() 和 DATEDIFF() 函数来获取当前月的开始和结束时间范围。以下是实现当前月时间范围查询的 SQL&#xff1a; FDATE > DATEADD(MONTH, DATEDIFF(MONTH, 0, GETDATE()), 0) FDATE < DATEADD(MONTH, DATEDIFF(MONTH, 0, GETDATE()) 1, 0) …

Go的Gin比java的Springboot更加的开箱即用?

前言 隔壁组的云计算零零后女同事&#xff0c;后文简称 云女士 &#xff0c;非说 Go 的 Gin 框架比 Springboot 更加的开箱即用&#xff0c;我心想在 Java 里面 Springboot 已经打遍天下无敌手&#xff0c;这份底蕴岂是 Gin 能比。 但是云女士突出一个执拗&#xff0c;非我要…

【2024最新Java面试宝典】—— SpringBoot面试题(44道含答案)

1. 什么是 Spring Boot&#xff1f; Spring Boot 是 Spring 开源组织下的子项目&#xff0c;是 Spring 组件一站式解决方案&#xff0c;主要是简化了使用 Spring 的难度&#xff0c;简省了繁重的配置&#xff0c;提供了各种启动器&#xff0c;使开发者能快速上手。 2. 为什么…

【PlantUML系列】用例图(三)

目录 一、组成部分 二、典型案例 一、组成部分 参与者&#xff08;Actors&#xff09;&#xff1a;使用关键字 actor 后跟参与者的名称。用例&#xff08;Use Cases&#xff09;&#xff1a;使用关键字 usecase 后跟用例的名称和编号&#xff08;可选&#xff09;。系统边界…