人工智能|机器学习——循环神经网络的简洁实现

循环神经网络的简洁实现

如何使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型。 我们仍然从读取时光机器数据集开始。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
 
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

定义模型

高级API提供了循环神经网络的实现。 我们构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer。 事实上,我们还没有讨论多层循环神经网络的意义。 现在仅需要将多层理解为一层循环神经网络的输出被用作下一层循环神经网络的输入就足够了。

num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

我们使用张量来初始化隐状态,它的形状是(隐藏层数,批量大小,隐藏单元数)。

state = torch.zeros((1, batch_size, num_hiddens))
state.shape

torch.Size([1, 32, 256])

通过一个隐状态和一个输入,我们就可以用更新后的隐状态计算输出。 需要强调的是,rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256]))

我们为一个完整的循环神经网络模型定义了一个RNNModel类。 注意,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

#@save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)
 
    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state
 
    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return  torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers,
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                        self.num_directions * self.rnn.num_layers,
                        batch_size, self.num_hiddens), device=device))

 训练与预测

在训练模型之前,让我们基于一个具有随机权重的模型进行预测。

device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

 很明显,这种模型根本不能输出好的结果。 接下来,我们使用定义的超参数调用train_ch8,并且使用高级API训练模型。 

num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

perplexity 1.3, 404413.8 tokens/sec on cuda:0 time travellerit would be remarkably convenient for the historia travellery of il the hise fupt might and st was it loflers

由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/188456.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构实验】排序(一)冒泡排序改进算法 Bubble及其性能分析

文章目录 1. 引言2. 冒泡排序算法原理2.1 传统冒泡排序2.2 改进的冒泡排序 3. 实验内容3.1 实验题目(一)输入要求(二)输出要求 3.2 算法实现 4. 实验结果5. 实验结论 1. 引言 排序算法是计算机科学中一个重要而基础的研究领域&…

⑦【Redis GEO 】Redis常用数据类型:GEO [使用手册]

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ Redis GEO ⑦Redis GEO 基本操作命令1.geoadd …

BGP笔记全

自治系统---AS 定义:由一个单一的机构或者组织所管理的一系列IP网络及其设备所构成的集合。 AS划分的原因 如果整张网络很大,路由数量进一步增加,路由表规模变得太大,会导致路由收敛速度变慢,设备性能消耗加大&#…

paho mqtt的keepAliveInterval

一、keepAliveInterval 所用的版本为1.3.12 实验一、 这个值设置的30,打开mqtt的trace,发现每隔33s发送一次pingreq note: 期间,client和server一直保持qos0的消息交互(client->server) 实验二、 …

activiti流程回退与跳转

学习连接 【工作流Activiti7】3、Activiti7 回退与会签 【工作流Activiti7】4、Activiti7 结束/终止流程 Activiti-跳转到指定节点、回退 ativiti6.0 流程节点自由跳转实现、拒绝/不同意/返回上一节点、流程撤回、跳转、回退等操作(通用实现,亲测可用…

1panel可视化Docker面板安装与使用

官网地址1Panel - 现代化、开源的 Linux 服务器运维管理面板 文章目录 目录 文章目录 前言 一、环境准备 二、使用步骤 1.安装命令 2.一些命令 3.使用 总结 前言 一、环境准备 虚拟机centos 已经安装好docker和 Docker Compose 或者都没安装 1panel会帮你自动安装 二、使用…

使用YOLOV8 CLI训练自己的数据集

YOLOV8现在可以直接通过命令行工具运行训练, 推理过程了, 方法如下, 首先安装ultralytics的包: pip install ultralytics接着尝试使用yolov8n来简单做个推理: yolo taskdetect modepredict modelyolov8n.pt conf0.25 sourcesome_picture.jpeg接下来我们使用一个安全防护, 包括…

【SpringCloud】设计原则之单一职责与服务拆分

一、设计原则之单一职责 设计原则很重要的一点就是简单,单一职责也就是所谓的专人干专事 一个单元(一个类、函数或微服务)应该有且只有一个职责 无论如何,一个微服务不应该包含多于一个的职责 职责单一的后果之一就是职责单…

【数据结构实验】图(二)将邻接矩阵存储转换为邻接表存储

文章目录 1. 引言2. 邻接表表示图的原理2.1 有向权图2.2 无向权图2.3 无向非权图2.1 有向非权图 3. 实验内容3.1 实验题目(一)数据结构要求(二)输入要求(三)输出要求 3.2 算法实现 4. 实验结果 1. 引言 图是…

软件测试 | MySQL 主键约束详解:保障数据完整性与性能优化

📢专注于分享软件测试干货内容,欢迎点赞 👍 收藏 ⭐留言 📝 如有错误敬请指正!📢交流讨论:欢迎加入我们一起学习!📢资源分享:耗时200小时精选的「软件测试」资…

【C指针(五)】6种转移表实现整合longjmp()/setjmp()函数和qsort函数详解分析模拟实现

🌈write in front :🔍个人主页 : 啊森要自信的主页 ✏️真正相信奇迹的家伙,本身和奇迹一样了不起啊! 欢迎大家关注🔍点赞👍收藏⭐️留言📝>希望看完我的文章对你有小小的帮助&am…

Keil5个性化设置及常用快捷键

Keil5个性化设置及常用快捷键 1.概述 这篇文章是Keil工具介绍的第三篇文章,主要介绍下Keil5优化配置,以及工作中常用的快捷键提高开发效率。 第一篇:《安装嵌入式单片机开发环境Keil5MDK以及整合C51开发环境》https://blog.csdn.net/m0_380…

⑧【HyperLoglog】Redis数据类型:HyperLoglog [使用手册]

个人简介:Java领域新星创作者;阿里云技术博主、星级博主、专家博主;正在Java学习的路上摸爬滚打,记录学习的过程~ 个人主页:.29.的博客 学习社区:进去逛一逛~ Redis HyperLoglog ⑧Redis HyperLoglog基本操…

如何把自己银行卡里的钱转账充值到自己支付宝上?

原文来源:https://www.caochai.com/article-4524.html 支付宝余额是支付宝核心功能之一,主要用于网购支付、线下支付、转账等场景。用户可以将银行卡、余额宝等资金转入或转出至支付宝余额,实现快速转账和支付。 如何把自己银行卡里的钱转账…

用Python进行数据分析:探索性数据分析的实践与技巧(文末送书)

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

停车管理系统

1 用户信息管理 2 车位信息管理 3 车位费用设置 4 停泊车辆查询 5 车辆进出管理 6 用户个人中心 7 预定停车位 8 缴费信息 9 业务逻辑详解 1 用户停车:user用户登录,在预定停车位菜单,选择一个车位点击预定即可 2 车辆驶出:admin…

【数据结构实验】排序(二)希尔排序算法的详细介绍与性能分析

文章目录 1. 引言2. 希尔排序算法原理2.1 示例说明2.2 时间复杂性分析 3. 实验内容3.1 实验题目(一)输入要求(二)输出要求 3.2 算法实现3.3 代码解析3.4 实验结果 4. 实验结论 1. 引言 排序算法在计算机科学中扮演着至关重要的角色…

Python武器库开发-前端篇之CSS元素(三十二)

前端篇之CSS元素(三十二) CSS 元素是一个网页中的 HTML 元素,包括标签、类和 ID。它们可以通过 CSS 选择器选中并设置样式属性,以使网页呈现具有吸引力和良好的可读性。常见的 HTML 元素包括 div、p、h1、h2、span 等,它们可以使用 CSS 设置…

王者农药小游戏

游戏运行如下: sxt Background package sxt;import java.awt.*; //背景类 public class Background extends GameObject{public Background(GameFrame gameFrame) {super(gameFrame);}Image bg Toolkit.getDefaultToolkit().getImage("C:\\Users\\24465\\D…

图的邻接矩阵,邻接表的C语言实现(408真题)

图的邻接矩阵 数据结构定义 #define MAXV 50;//顶点数目的最大值 typedef struct{int vex[MAX]; //顶点表 int edge[MAXV][MAXV]; //邻接矩阵 int edgeNum,vexNum; //图中实际的边数和顶点数 }MGraph;初始化 void Matrix_Init(MGraph *Mgraph) {int v1, v2;//存储有边的…