Deep Learning Part Five RNNLM的学习和评价-24.4.30

准备好RNNLM所需要的层,我们现在来实现RNNLM,并对其进行训练,然后再评价一下它的结果的。

5.5.1 RNNLM的实现

这里我们将RNNLM使用的网络实现为SimpleRnnlm类,其层结构如下:

如图 5-30 所示,SimpleRnnlm 类是一个堆叠了 4 个 Time 层的神经网络。我们先来看一下初始化的代码:

import sys
sys.path.append('..')
import numpy as np
from common.time_layers import *

class SimpleRnnlm:
    def __init__(self, vocab_size, wordvec_size, hidden_size):
        V, D, H = vocab_size, wordvec_size, hidden_size
        rn = np.random.randn

        # 初始化权重
        embed_W = (rn(V, D) / 100).astype('f')
        rnn_Wx = (rn(D, H) / np.sqrt(D)).astype('f')
        rnn_Wh = (rn(H, H) / np.sqrt(H)).astype('f')
        rnn_b = np.zeros(H).astype('f')
        affine_W = (rn(H, V) / np.sqrt(H)).astype('f')
        affine_b = np.zeros(V).astype('f')

        # 生成层
        self.layers = [
            TimeEmbedding(embed_W),
            TimeRNN(rnn_Wx, rnn_Wh, rnn_b, stateful=True),
            TimeAffine(affine_W, affine_b)
        ]
        self.loss_layer = TimeSoftmaxWithLoss()
        self.rnn_layer = self.layers[1]

        # 将所有的权重和梯度整理到列表中
        self.params, self.grads = [], []
        for layer in self.layers:
            self.params += layer.params
            self.grads += layer.grads

拓展:

接着,我们来实现 forward() 方法、backward() 方法和 reset_state() 方法。

def forward(self, xs, ts):
    for layer in self.layers:
        xs = layer.forward(xs)
    loss = self.loss_layer.forward(xs, ts)
    return loss

def backward(self, dout=1):
    dout = self.loss_layer.backward(dout)
    for layer in reversed(self.layers):
        dout = layer.backward(dout)
    return dout

def reset_state(self):
    self.rnn_layer.reset_state()

从上述中,可以看出实现非常简单。在各个层中,正向传播和反向传播都正确地进行了实现。因此,我们只要以正确的顺序调用 forward()(或者 backward())即可。方便起见,这里将重设网络状态的方法实现为 reset_state()。以上就是对 SimpleRnnlm 类的说明。

5.5.3 RNNLM的学习代码

下面,我们使用 PTB 数据集进行学习,不过这里仅使用 PTB 数据集(训练数据)的前 1000 个单词。这是因为在本节实现的 RNNLM 中,即便使用所有的训练数据,也得不出好的结果。下一章我们将对它进行改进。

import sys
sys.path.append('..')
import matplotlib.pyplot as plt
import numpy as np
from common.optimizer import SGD
from dataset import ptb
from simple_rnnlm import SimpleRnnlm


# 设定超参数
batch_size = 10
wordvec_size = 100
hidden_size = 100 # RNN的隐藏状态向量的元素个数
time_size = 5 # Truncated BPTT的时间跨度大小
lr = 0.1
max_epoch = 100

# 读入训练数据(缩小了数据集)
corpus, word_to_id, id_to_word = ptb.load_data('train')
corpus_size = 1000
corpus = corpus[:corpus_size]
vocab_size = int(max(corpus) + 1)

xs = corpus[:-1] # 输入
ts = corpus[1:] # 输出(监督标签)
data_size = len(xs)
print('corpus size: %d, vocabulary size: %d' % (corpus_size, vocab_size))

# 学习用的参数
max_iters = data_size // (batch_size * time_size)
time_idx = 0
total_loss = 0
loss_count = 0
ppl_list = []

# 生成模型
model = SimpleRnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)

# ❶ 计算读入mini-batch的各笔样本数据的开始位置
jump = (corpus_size - 1) // batch_size
offsets = [i * jump for i in range(batch_size)]

for epoch in range(max_epoch):
    for iter in range(max_iters):
        # ❷ 获取mini-batch
        batch_x = np.empty((batch_size, time_size), dtype='i')
        batch_t = np.empty((batch_size, time_size), dtype='i')
        for t in range(time_size):
            for i, offset in enumerate(offsets):
                batch_x[i, t] = xs[(offset + time_idx) % data_size]
                batch_t[i, t] = ts[(offset + time_idx) % data_size]
            time_idx += 1

        # 计算梯度,更新参数
        loss = model.forward(batch_x, batch_t)
        model.backward()
        optimizer.update(model.params, model.grads)
        total_loss += loss
        loss_count += 1

    # ❸ 各个epoch的困惑度评价
    ppl = np.exp(total_loss / loss_count)
    print('| epoch %d | perplexity %.2f'
          % (epoch+1, ppl))
    ppl_list.append(float(ppl))
    total_loss, loss_count = 0, 0

只摘录了核心:

...
from common.trainer import RnnlmTrainer

...
model = SimpleRnnlm(vocab_size, wordvec_size, hidden_size)
optimizer = SGD(lr)
trainer = RnnlmTrainer(model, optimizer)

trainer.fit(xs, ts, max_epoch, batch_size, time_size)

如上所示,首先使用 model 和 optimizer 初始化 RnnlmTrainer 类,然后调用 fit(),完成学习。此时,RnnlmTrainer 类的内部将执行上一节进行的一系列操作,具体如下所示。

  • 按顺序生成 mini-batch
  • 调用模型的正向传播和反向传播
  • 使用优化器更新权重
  • 评价困惑度

使用Trainer的好处:

 使用 RnnlmTrainer 类,可以避免每次写重复的代码。本书的剩余部分都将使用 RnnlmTrainer 类学习 RNNLM。

5.6 小结

本章的主题是 RNN。RNN 通过数据的循环,从过去继承数据并传递到现在和未来。如此,RNN 层的内部获得了记忆隐藏状态的能力。本书中我们花了很多时间说明 RNN 层的结构,并实现了 RNN 层(和 Time RNN 层)。

本章还利用 RNN 创建了语言模型。语言模型给单词序列赋概率值。特别地,条件语言模型从已经出现的单词序列计算下一个将要出现的单词的概率。通过构成利用了 RNN 的神经网络,理论上无论多么长的时序数据,都可以将它的重要信息记录在 RNN 的隐藏状态中。但是,在实际问题中,这样一来,许多情况下学习将无法顺利进行。下一章我们将指出 RNN 存在的问题,并研究替代 RNN 的 LSTM 层或 GRU 层。这些层在处理时序数据方面非常重要,被广泛用于前沿研究。

本章所学的内容

  • RNN 具有环路,因此可以在内部记忆隐藏状态
  • 通过展开 RNN 的循环,可以将其解释为多个 RNN 层连接起来的神经网络,可以通过常规的误差反向传播法进行学习(= BPTT)
  • 在学习长时序数据时,要生成长度适中的数据块,进行以块为单位的 BPTT 学习(= Truncated BPTT)
  • Truncated BPTT 只截断反向传播的连接
  • 在 Truncated BPTT 中,为了维持正向传播的连接,需要按顺序输入数据
  • 语言模型将单词序列解释为概率
  • 理论上,使用 RNN 层的条件语言模型可以记忆所有已出现单词的信息

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/590329.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

调教AI给我写了一个KD树的算法

我不擅长C,但是目前需要用C写一个KD树的算法。首先我有一份点云数据,需要找给定坐标范围0.1mm内的所有点。 于是我开始问AI,他一开始给的答案,完全是错误的,但是我一步步给出反馈,告诉他的问题,…

基于Springboot的交流互动系统

基于SpringbootVue的交流互动系统的设计与实现 开发语言:Java数据库:MySQL技术:SpringbootMybatis工具:IDEA、Maven、Navicat 系统展示 用户登录 首页 帖子信息 聚会信息 后台登录 后台管理首页 用户管理 帖子分类管理 帖子信息…

【模板】二维前缀和

原题链接:登录—专业IT笔试面试备考平台_牛客网 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 二维前缀和板题。 二维前缀和:pre[i][j]a[i][j]pre[i-1][j]pre[i][j-1]-pre[i-1][j-1]; 子矩阵 左上角为(x1,y1) 右下角(x2,y2…

自然语言处理基础

文章目录 一、基础与应用简单介绍基本任务重要应用 二、词表示与语言模型词表示方案一:用一组的相关词来表示当前词方案二:one-hot representation,将每一个词表示成一个独立的符号方案三:上下文表示法(contextual rep…

Mamba3D革新3D点云分析:超越Transformer,提升本地特征提取效率与性能!

DeepVisionary 每日深度学习前沿科技推送&顶会论文分享,与你一起了解前沿深度学习信息! Mamba3D革新3D点云分析:超越Transformer,提升本地特征提取效率与性能! 引言:3D点云分析的重要性与挑战 3D点云…

Python语言零基础入门——文件

目录 一、文件的基本概念 1.文件 2.绝对路径与相对路径 3.打开文件的模式 二、文件的读取 三、文件的追加 四、文件的写入 五、with语句 六、csv文件 1.csv文件的读取 2.csv文件的写入 七、练习题:实现日记本 一、文件的基本概念 1.文件 文件是以计算…

【Android学习】简易计算器的实现

1.项目基础目录 新增dimens.xml 用于控制全部按钮的尺寸。图片资源放在drawable中。 另外 themes.xml中原来的 <style name"Theme.Learn" parent"Theme.MaterialComponents.DayNight.DarkActionBar">变为了&#xff0c;加上后可针对button中增加图片…

最新AI创作系统,ChatGPT商业运营系统网站源码,SparkAi-v6.5.0,Ai绘画/GPTs应用,文档对话

一、文章前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧。已支持…

【C语言的完结】:最后的测试题

看到这句话的时候证明&#xff1a; 此刻你我都在努力~ 个人主页&#xff1a; Gu Gu Study ​​ 专栏&#xff1a;语言的起点-----C语言 喜欢的一句话&#xff1a; 常常会回顾努力的自己&#xff0c;所以要为自己的努力留下足迹…

Delta lake with Java--数据增删改查

之前写的关于spark sql 操作delta lake表的&#xff0c;总觉得有点混乱&#xff0c;今天用Java以真实的数据来进行一次数据的CRUD操作&#xff0c;所涉及的数据来源于Delta lake up and running配套的 GitGitHub - benniehaelen/delta-lake-up-and-running: Companion reposito…

软件无线电系列——信道编译码

微信公众号上线&#xff0c;搜索公众号小灰灰的FPGA,关注可获取相关源码&#xff0c;定期更新有关FPGA的项目以及开源项目源码&#xff0c;包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 本节目录 一、信道编译码 1、数字…

开源的贴吧数据查询工具

贴吧数据查询工具 这是一个贴吧数据查询工具&#xff0c;目前仍处于开发阶段。 本地运行 要本地部署这个项目&#xff0c;请 克隆这个仓库并前往项目目录 git clone https://github.com/Dilettante258/tieba-tools.git cd tieba-tools安装依赖 pnpm install运行项目 np…

服务器数据恢复—异常断电导致RAID模块故障的数据恢复案例

服务器数据恢复环境&#xff1a; 某品牌ProLiant DL380系列服务器&#xff0c;服务器中有一组由6块SAS硬盘组建的RAID5阵列&#xff0c;WINDOWS SERVER操作系统&#xff0c;作为企业内部文件服务器使用。 服务器故障&#xff1a; 机房供电几次意外中断&#xff0c;服务器出现故…

RMQ从入门到精通

一.概述与安装 //RabbitMQ //1.核心部分-高级部分-集群部分 //2.什么是MQ 消息队列message queue 先入先出原则;消息通信服务 //3.MQ的大三功能 流量消峰 应用解耦 消息中间件 //&#xff08;1&#xff09;人-订单系统(1万次/S)—> 人 - MQ(流量消峰,对访问人员进行排队) -…

Java 【数据结构】常见排序算法实用详解(上) 插入排序/希尔排序/选择排序/堆排序【贤者的庇护】

登神长阶 上古神器-常见排序算法 插入排序/选择排序/堆排序 &#x1f4d4; 一.排序算法 &#x1f4d5;1.排序的概念 排序 &#xff1a;所谓排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&a…

【Python】函数设计

1.联系函数的设计 2.找质数 3.找因子 4.判断水仙花数 5.斐波拉契数列递归调用&#xff0c;并用数组存储已计算过的数&#xff0c;减少重复计算 1、计算利息和本息 编写两个函数分别按单利和复利计算利息,根据本金、年利率、存款年限得到本息和和利息。调用这两个函数计算1…

学习Rust的第22天:mini_grep第2部分

书接上文&#xff0c;在本文中&#xff0c;我们学习了如何通过将 Rust 程序的逻辑移至单独的库箱中并采用测试驱动开发 (TDD) 实践来重构 Rust 程序。通过在实现功能之前编写测试&#xff0c;我们确保了代码的可靠性。我们涵盖了基本的 Rust 概念&#xff0c;例如错误处理、环境…

【linux-汇编-点灯之思路-程序】

目录 1. ARM汇编中的一些注意事项2. IMXULL汇编点灯的前序&#xff1a;3. IMXULL汇编点灯之确定引脚&#xff1a;4. IMXULL汇编点灯之引脚功能编写&#xff1a;4.1 第一步&#xff0c;开时钟4.2 第二步&#xff0c;定功能&#xff08;MUX&#xff09;4.3 第三步&#xff0c;定电…

【笔试训练】day17

1.小乐乐该数字 遇到按位处理的情况可以考虑用字符串去读 代码&#xff1a; #define _CRT_SECURE_NO_WARNINGS 1 #include <iostream> #include<string> using namespace std;int main() {string str;cin >> str;int ans 0;for (int i 0; i < str.siz…

JavaEE 初阶篇-深入了解 Junit 单元测试框架和 Java 中的反射机制(使用反射做一个简易版框架)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 Junit 单元测试框架概述 1.1 使用 Junit 框架进行测试业务代码 1.2 Junit 单元测试框架的常用注解&#xff08;Junit 4.xxx 版本&#xff09; 2.0 反射概述 2.1 获…