PyTorch -- RNN 快速实践

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as plt
      
      import torch
      import torch.nn as nn
      import torch.optim as optim
      
      
      seq_len     = 50
      batch       = 1
      num_time_steps = seq_len
      
      input_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True 
      
      class Net(nn.Module):  ## model 定义
      	def __init__(self):
      		super(Net, self).__init__()
      		self.rnn = nn.RNN(
      		input_size=input_size,
      		hidden_size=hidden_size,
      		num_layers=num_layers,
      		batch_first=batch_first)
      		# for p in self.rnn.parameters():
      		# 	nn.init.normal_(p, mean=0.0, std=0.001)
      		self.linear = nn.Linear(hidden_size, output_size)
      
      	def forward(self, x, hidden_prev):
      		out, hidden_prev = self.rnn(x, hidden_prev)
      		# out: [batch, seq_len, hidden_size]
      		out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]
      		out = self.linear(out) 			 # [batch*seq_len, output_size]
      		out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]
      		return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01
      
      def tarin_RNN():
          model = Net()
          print('model:\n',model)
          criterion = nn.MSELoss()
          optimizer = optim.Adam(model.parameters(), lr)
          hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化h
      
          l = []
          for iter in range(100):  # 训练100次
              start = np.random.randint(10, size=1)[0]  ## 序列起点
              time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
              data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      
              x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
              y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点
      
              output, hidden_prev = model(x, hidden_prev)
              hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detach
      
              loss = criterion(output, y)
              model.zero_grad()
              loss.backward()
              optimizer.step()
      
              if iter % 100 == 0:
                  print("Iteration: {} loss {}".format(iter, loss.item()))
                  l.append(loss.item())
          #############################绘制损失函数#################################
          plt.plot(l,'r')
          plt.xlabel('训练次数')
          plt.ylabel('loss')
          plt.title('RNN LOSS')
          plt.savefig('RNN_LOSS.png')
          return hidden_prev,model
      
       hidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    
      
      predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):
          input = input.view(1, 1, 1)
          pred, hidden_prev = model(input, hidden_prev)
          input = pred  ## 循环获得每个input点输入网络
          predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"
    	p.grad.nomr()
    	torch.nn.utils.clip_grad_norm_(p, 10)  ## 其中的 norm 后面的_ 表示 in place
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/725035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu 24.04安装zabbix7.0.0图形中文乱码

当zabbix安装完成后,设置中文界面时,打开图形,中文内容会显示方框乱码,是因为服务器字体中没有相关的中文字体,需要更换。 1、找到中文字体,可以在网络上下载《得意黑》开源字体,也可以在windo…

LeetCode322.零钱兑换(一)

LeetCode刷题记录 文章目录 📜题目描述💡解题思路⌨C代码 📜题目描述 给你一个整数数组 coins ,表示不同面额的硬币;以及一个整数 amount ,表示总金额。 计算并返回可以凑成总金额所需的 最少的硬币个数 。…

SAP MIGO 050 BADI:字段 GOITEM-XXXXX 未准备好输出

背景: MIGO过账时候需要根据某些条件更改某些字段的值,当要改的字段在前台不显示时,通过MB_MIGO_BADI~LINE_MODIFY去更改时,则会出现以下报错:MIGO050 解决方案1: 通过配置将该字段配置显示出来即可&…

阿里云如何部署项目【2024 详细版】

首次注册阿里云后可以购买免费服务器,可以用服务器练习部署项目,这里以部署个人网站为例 本人目前没有购买域名,因此域名流程并没有写,有看不懂的私信或者评论就行,我都可以看见 目录 一、购买服务器 二、安装宝塔…

「Python-docx 专栏」docx设置罗马数字页码,即页码编码格式为罗马数字

本文目录 前言一、docx 设置罗马数字页码1、docx设置大写罗马数字的页码①、docx背后的xml长啥样②、<w:sectPr> 标签详解③、通过<w:sectPr> 设置大写罗马数字的页码A、完整代码B、处理效果图C、这段代码实际上的作用2、docx设置小写罗马数字的页码①、完整代码②…

vue3前端对接后端的图片验证码

vue3前端对接后端的图片验证码 <template> <image :src"captchaUrl" alt"图片验证码" click"refreshCaptcha"></image> </template><script setup>import {ref} from "vue";import {useCounterStore} …

vue 2.0

自定义vue标签指令&#xff1a; <!DOCTYPE html> <html lang"en"> <script src"vue.js"></script> <head><meta charset"UTF-8"><title>Title</title> </head> <body> <div id…

Prometheus告警Alertmanager部署

Prometheus告警Alertmanager部署 资源监控一般离不开预警&#xff0c;因为我们不可能每时每刻都盯着某个资源监控看&#xff0c;而且在实际的工作中当中我们搭建的解决方案涉及到的服务器是多台甚至数十台&#xff0c;所以更加不现实&#xff0c;因此资源告警是一个必不可少的…

3ds Max软件下载安装:3D建模软件 轻松开启你的建模之旅!

3ds Max&#xff0c;在建模过程中&#xff0c;网格建模和NURBS建模两大技术发挥着不可或缺的作用。网格建模允许用户通过顶点、边和面等元素的调整&#xff0c;精确地塑造出模型的形态&#xff1b;而NURBS建模则以其优秀的曲线和曲面处理能力&#xff0c;为设计师们提供了更为平…

ChinaTravel成流量密码,景区如何打造视频监控管理平台提升旅游体验

随着中国经济的飞速发展和人民生活水平的持续提高&#xff0c;旅游已经成为越来越多人休闲放松的首选方式。近期&#xff0c;随着互联网的普及和社交媒体的兴起&#xff0c;以及免签政策带火入境游&#xff0c;“ChinaTravel”已成为社交网络上的一大流量密码&#xff0c;吸引了…

1. ELK日志分析

ELK日志分析 一、ELK作用、组件1、作用2、核心组件2.1 beat软件2.1 Logstash2.2 Elasticsearch2.3 Kibana 二、ELK部署、测试1、环境规划2、确保SELinux关闭、时间同步3、所有主机添加主机名解析4、三台ES主机安装jdk 1.155、调整系统资源限制6、部署es集群6.1 创建普通用户elk…

AI口语练习APP的技术难点

AI口语练习APP旨在帮助用户练习口语&#xff0c;因此其核心功能是语音识别和语音评测。以下是一些AI口语练习APP的主要技术难点。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1. 语音识别 语音识别是将语音信号转换为文本的过程。…

C++ —— unordered_set、unordered_map的介绍及使用

目录 unordered系列关联式容器 unordered_set的介绍 unordered_set的使用 unordered_set的定义方式 unordered_set接口的使用 unordered_multiset unordered_map的介绍 unordered_map的使用 unordered_map的定义方式 unordered_map接口的使用 unordered_multimap …

机器学习周记(第四十三周:MCformer)2024.6.10~2024.6.16

目录 摘要ABSTRACT1 论文信息1.1 论文标题1.2 论文摘要1.3 论文引言1.4 论文贡献 2 论文模型2.1 问题定义2.2 可逆实例归一化&#xff08;Reversible Instance Normalization&#xff09;2.3 混合通道块 &#xff08;Mixed-Channels Block&#xff09;2.4 编码器&#xff08;De…

安全可靠跨国传输的前提下,如何兼顾数据跨国快速传输?

在全球化的商业环境中&#xff0c;跨国公司在与国际客户、合作伙伴或海外分支机构进行数据跨国快速传输时&#xff0c;不可避免会遇到一系列挑战。比如网络延迟、数据包丢失、带宽限制以及数据安全和合规性问题&#xff0c;一定程度上都会影响数据传输的效率&#xff0c;业务的…

项目的打包

一:打包到微信小程序 1)vscode打包 2)在微信小程序开发工具中打开路径,上传. 疑问:为什么pnpm bulid:mp-weixin用于打包,pnpm dev:mp-weixin也可生成对应路径下的文件?? 打包的是没有热重载,且打包体积更小. 二:条件编译 vscode可以打包成能在不同平台上运行的代码.但是有…

大数据关联规则算法

关联性&#xff08;Association&#xff09; 定义&#xff1a;指一个变量能够提供有关另一个变量的信息。特点&#xff1a;关联性是一个广泛的概念&#xff0c;它可以包括直接的、间接的、强的或弱的联系。 相关性&#xff08;Correlation&#xff09; 定义&#xff1a;指两个…

新手搭建Magic-API

项目场景&#xff1a; 我本是一个前端和GIS开发工程师&#xff0c;但新单位并没有配置完整的开发团队&#xff0c;确切说目前只有我一个人做开发&#xff0c;那么肯定避免不了要研究下后端。最近有一个小程序要开发&#xff0c;管理平台我直接用的fastAdminthinkphp写完了页面…

IAM风险CTF挑战赛

wiz启动了一个名为“The Big IAM Challenge”云安全CTF挑战赛。旨在让白帽子识别和利用 IAM错误配置&#xff0c;并从现实场景中学习&#xff0c;从而更好的认识和了解IAM相关的风险。比赛包括6个场景&#xff0c;每个场景都专注于各种AWS服务中常见的IAM配置错误。 Challenge…

企业工程图纸很多,应该如何进行图纸管理?

企业工程图纸很多&#xff0c;应该如何进行图纸管理&#xff1f; 设计制造企业在实际设计和生产过程中会产生大量的工程图纸&#xff0c;图纸一多管理起来就会十分麻烦&#xff0c;管理不当则是会影响整体的工作效率。对于大量工程图纸的管理&#xff0c;有多种方式方法来进行…