动手学深度学习——循环神经网络的简洁实现(代码详解)

文章目录

    • 循环神经网络的简洁实现
      • 1. 定义模型
      • 2. 训练与预测

循环神经网络的简洁实现

# 使用深度学习框架的高级API提供的函数更有效地实现相同的语言模型
import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

1. 定义模型

构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer

# 构造一个具有256个隐藏单元的单隐藏层的循环神经网络层rnn_layer
num_hiddens =256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

使用张量初始化状态,形状为(隐藏层数,批量大小,隐藏单元数)

# 使用张量初始化状态,形状为(隐藏层数,批量大小,隐藏单元数)
state = torch.zeros((1, batch_size, num_hiddens))
state.shape

在这里插入图片描述
通过一个隐状态和一个输入,可以用更新后的隐状态计算输出。

# 通过一个隐状态和一个输入,可以用更新后的隐状态计算输出。
# rnn_layer的“输出”(Y)不涉及输出层的计算: 它是指每个时间步的隐状态,这些隐状态可以用作后续输出层的输入。
X = torch.rand(size=(num_steps, batch_size, len(vocab)))
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape

在这里插入图片描述为一个完整的循环神经网络模型定义了一个RNNModel类,rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层。

# 为一个完整的循环神经网络模型定义了一个RNNModel类
# rnn_layer只包含隐藏的循环层,我们还需要创建一个单独的输出层
#save
class RNNModel(nn.Module):
    """循环神经网络模型"""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # 如果RNN是双向的(之后将介绍),num_directions应该是2,否则应该是1
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)
        
    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # 全连接层首先将Y的形状改为(时间步数*批量大小,隐藏单元数)
        # 它的输出形状是(时间步数*批量大小,词表大小)。
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state
        
        
    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # nn.GRU以张量作为隐状态
            return torch.zeros((self.num_directions * self.rnn.num_layers, 
                               batch_size, self.num_hiddens), 
                              device=device)
        else:
            # nn.LSTM以元组作为隐状态
            return (torch.zeros((
                self.num_directions * self.rnn.num_layers, 
                batch_size, self.num_hiddens), device=device),
                    torch.zeros((
                    self.num_directions * self.rnn.num_layers,
                    batch_size, self.num_hiddens), device=device))

2. 训练与预测

在训练模型之前,基于一个具有随机权重的模型进行预测。

# 在训练模型之前,基于一个具有随机权重的模型进行预测。
device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab))
net = net.to(device)
d2l.predict_ch8('time traveller', 10, net, vocab, device)

在这里插入图片描述
使用之前的超参数调用train_ch8,并且使用高级API训练模型

# 使用之前的超参数调用train_ch8,并且使用高级API训练模型
num_epochs, lr = 500, 1
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/161639.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

黑马程序员 学成在线项目 第1章 项目介绍环境搭建v3.1

第1章 项目介绍&环境搭建v3.1 1.项目背景 1.1 在线教育市场环境 以下内容摘自艾瑞:2020年在线教育行业洞察:To B赛道篇_网络服务_艾瑞网 在线教育行业是一个有着极强的广度和深度的行业,从校内到校外;从早幼教到职业培训&…

【Python】逆向与爬虫的故事

目录 一、前言 二、爬虫 1、什么是爬虫? 2、Python 爬虫的主要工具 3、爬虫的基本流程 4、实例代码 三、逆向 1、什么是逆向? 2、Python 逆向的主要工具 3、逆向的基本流程 4、实例代码 四、总结 一、前言 随着互联网技术的发展&#xff0c…

RIP路由信息协议

RIP路由信息协议(Routing Information Protocol) 最先得到广泛应用的协议,最大优点是简单要求网络中的每个路由器都要维护一张表,表中记录了从它自己到其他每一个目的网络的距离RIP是应用层协议,它在传输层使用UDP,RIP报文作为UD…

2023.11.18 - hadoop之zookeeper分布式协调服务

1.zookeeper简介 ZooKeeper概念: Zookeeper是一个分布式协调服务的开源框架。本质上是一个分布式的小文件存储系统 ZooKeeper作用: 主要用来解决分布式集群中应用系统的一致性问题。 ZooKeeper结构: 采用树形层次结构,没有目录与文件之分,ZooKeeper树中的每个节点被…

HDMI之EDID析义篇

DisplayID Type X Video Timing Data Block 实例 F0 2A 10 93 FF 0E 6F 08 8F 10 93 7F 07 37 04 8F 10该数据来源于SHARP AQUOS-TVE23A 4K144Hz电视机的第3个EDID块(基于HF-EEODB)。 定义 解释 VTDB 1: 3840x2160 144.000009 Hz 16:9 333.216 kHz 1343.527000 MHz (RBv3,h…

STM32 HAL库函数HAL_SPI_Receive_IT和HAL_SPI_Receive的区别

背景 前段时间开发一个按键板驱动,该板用的STM32F103系列单片机,前任工程师用STM32CubeMX生成的工程,里面全是HAL库调用,我接手后,学习了下HAL库的用法,踩坑不少,特别是带IT后缀的函数&#xf…

深入了解Java 8 新特性:lambda表达式进阶

阅读建议 嗨,伙计!刷到这篇文章咱们就是有缘人,在阅读这篇文章前我有一些建议: 本篇文章大概7000多字,预计阅读时间长需要10分钟。本篇文章的实战性、理论性较强,是一篇质量分数较高的技术干货文章&#…

【GUI】-- 09 JComboBox JList、JTextField JPasswordField JTextArea

GUI编程 03 Swing 3.6 列表 下拉框 package com.duo.lesson06;import javax.swing.*; import java.awt.*;public class ComboBoxDemo01 extends JFrame {public ComboBoxDemo01() throws HeadlessException {Container contentPane getContentPane();JComboBox<Object&…

stable diffusion十七种controlnet详细使用方法总结

个人网站&#xff1a;https://tianfeng.space 前言 最近不知道发点什么&#xff0c;做个controlnet 使用方法总结好了&#xff0c;如果你们对所有controlnet用法&#xff0c;可能了解但是有点模糊&#xff0c;希望能对你们有用。 一、SD controlnet 我统一下其他参数&#…

如何去掉照片中多余路人?一分钟帮你搞定

在外出拍照时&#xff0c;可能会遇到一些不希望出现在照片中的路人&#xff0c;比如在旅游景点、公共场所或者街头拍摄时突然闯入镜头的人。这些路人的出现可能会破坏照片的整体氛围&#xff0c;影响照片的美观度。因此&#xff0c;需要使用一些方法去掉这些多余的路人&#xf…

mysql慢查询日志分析工具(pt-query-digest)

首先说下安装mysql自带的分析工具&#xff1a;mysqldumpslow mysqldumpslow -t 3 /var/lib/mysql/localhost-slow.log 因为mysqldumpslow看到的信息有限&#xff0c;只是获取语句的基础信息&#xff0c;并不是很全面。下面介绍一个功能很强大的分析工具。 pt-query-digest pt…

【Python3】【力扣题】303. 区域和检索 - 数组不可变

【力扣题】题目描述&#xff1a; 【Python3】代码&#xff1a; 1、解题思路&#xff1a;从列表中获取指定下标的所有元素&#xff0c;求和。 知识点&#xff1a;列表[start:end]&#xff1a;切片。从列表中获取起始下标start&#xff08;含&#xff09;到结束下标end&#xf…

解决:虚拟机远程连接失败

问题 使用FinalShell远程连接虚拟机的时候连接不上 发现 虚拟机用的VMware&#xff0c;Linux发行版是CentOs 7&#xff0c;发现在虚拟机中使用ping www.baidu.com是成功的&#xff0c;但是使用FinalShell远程连接不上虚拟机&#xff0c;本地网络也ping不通虚拟机&#xff0c…

腾讯云轻量级服务器和云服务器什么区别?轻量服务器是干什么用的

随着互联网的迅速发展&#xff0c;服务器成为了许多人必备的工具。然而&#xff0c;面对众多的服务器选择&#xff0c;我们常常会陷入纠结之中。在这篇文章中&#xff0c;我们将探讨轻量服务器和标准云服务器的区别&#xff0c;帮助您选择最适合自己需求的服务器。 腾讯云双十…

Swagger(3):Swagger入门案例

1 编写SpringBoot项目新建一个Rest请求控制器。package com.example.demo.controller;import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestMap…

Flask 接口

目录 前言 代码实现 简单接口实现 执行其它程序接口 携带参数访问接口 前言 有时候会想着开个一个接口来访问试试&#xff0c;这里就给出一个基础接口代码示例 代码实现 导入Flask模块&#xff0c;没安装Flask 模块需要进行 安装&#xff1a;pip install flask 使用镜…

腾讯云服务器怎么样好用吗?腾讯云服务器性能评测

近年来&#xff0c;腾讯云作为一家领先的云服务提供商&#xff0c;备受关注。尤其是最近两年&#xff0c;腾讯云在优惠活动上的力度非常大&#xff0c;被誉为良心云。其优惠政策吸引了越来越多的用户选择腾讯云作为他们的云服务提供商。 腾讯云双十一领9999代金券 https://111…

GamingTcUI.dll丢失修复,最全面的GamingTcUI.dll修复指南

热衷于电脑游戏的用户可能会在启动游戏时遇到这样的错误信息&#xff1a;"无法启动应用&#xff0c;因为找不到GamingTcUI.dll"。那么这个GamingTcUI.dll文件是什么&#xff1f;如何解决这个问题呢&#xff1f;我们将在本文中进行详细讲解。 一.GamingTcUI.dll是什么…

Activiti,Apache camel,Netflex conductor对比,业务选型

Activiti,Apache camel,Netflex conductor对比&#xff0c;业务选型 1.activiti是审批流&#xff0c;主要应用于人->系统交互&#xff0c;典型应用场景&#xff1a;请假&#xff0c;离职等审批 详情可见【精选】activti实际使用_activiti通过事件监听器实现的优势_记录点滴…