时序预测demo 代码快速实现 MLP效果比LSTM 好,简单模拟数据

【PyTorch修炼】用pytorch写一个经常用来测试时序模型的简单常规套路(LSTM多步迭代预测)

层数的理解:
LSTM(长短期记忆)的层数指的是在神经网络中堆叠的LSTM单元的数量。层数决定了网络能够学习的复杂性和深度。每一层LSTM都能够捕捉和记忆不同时间尺度的依赖关系,因此增加层数可以使网络更好地理解和处理复杂的序列数据。
在这里插入图片描述

LSTM方法:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

x = torch.linspace(0, 999, 1000)
y = torch.sin(x*2*3.1415926/70)

plt.xlim(-5, 1005)
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.title("sin")
plt.plot(y.numpy(), color='#800080')
plt.show()

x = torch.linspace(0, 999, 1000)
y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)

plt.plot(y.numpy(), color='#800080')
plt.title("Sine-Like Time Series")
plt.xlabel('Time')
plt.ylabel('Value')
plt.show()

train_y= y[:-70]
test_y = y[-70:]

def create_data_seq(seq, time_window):
    out = []
    l = len(seq)
    for i in range(l-time_window):
        x_tw = seq[i:i+time_window]
        y_label = seq[i+time_window:i+time_window+1]
        out.append((x_tw, y_label))
    return out
time_window = 60
train_data = create_data_seq(train_y, time_window)


class MyLstm(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, out_size=1):
        super(MyLstm, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=self.hidden_size, num_layers=1, bidirectional=False)
        self.linear = nn.Linear(in_features=self.hidden_size, out_features=out_size, bias=True)
        self.hidden_state = (torch.zeros(1, 1, self.hidden_size), torch.zeros(1, 1, self.hidden_size))

    def forward(self, x):
        out, self.hidden_state = self.lstm(x.view(len(x), 1, -1), self.hidden_state)
        pred = self.linear(out.view(len(x), -1))
        return pred[-1]


time_window = 60
train_data = create_data_seq(train_y, time_window)

learning_rate = 0.00001
epoch = 13
multi_step = 70

model=MyLstm()
mse_loss = nn.MSELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=learning_rate,betas=(0.5,0.999))

for i in range(epoch):
    for x_seq, y_label in train_data:
        x_seq = x_seq 
        y_label = y_label 
        model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,
                              torch.zeros(1, 1, model.hidden_size) )
        pred = model(x_seq)
        loss = mse_loss(y_label, pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {i} Loss: {loss.item()}")
    preds = []
    labels = []
    preds = train_y[-time_window:].tolist()
    for j in range(multi_step):
        test_seq = torch.FloatTensor(preds[-time_window:]) 
        with torch.no_grad():
            model.hidden_state = (torch.zeros(1, 1, model.hidden_size) ,
                                  torch.zeros(1, 1, model.hidden_size) )
            preds.append(model(test_seq).item())
    loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))
    print(f"Performance on test range: {loss}")

    plt.figure(figsize=(12, 4))
    plt.xlim(700, 999)
    plt.grid(True)
    plt.plot(y.numpy(), color='#8000ff')
    plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')
    plt.show()


class SimpleMLP(nn.Module):
    def __init__(self, input_size=60, hidden_size=128, output_size=1):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


mlp_model = SimpleMLP()
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(mlp_model.parameters(), lr=0.0001)
for i in range(epoch):
    for x_seq, y_label in train_data:
        x_seq = x_seq
        y_label = y_label
        pred = mlp_model(x_seq)
        loss = mse_loss(y_label, pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {i} Loss: {loss.item()}")
    preds = []
    labels = []
    preds = train_y[-time_window:].tolist()
    for j in range(multi_step):
        test_seq = torch.FloatTensor(preds[-time_window:])
        with torch.no_grad():
            preds.append(mlp_model(test_seq).item())
    loss = mse_loss(torch.tensor(preds[-multi_step:]), torch.tensor(test_y))
    print(f"Performance on test range: {loss}")

    plt.figure(figsize=(12, 4))
    plt.xlim(700, 999)
    plt.grid(True)
    plt.plot(y.numpy(), color='#8000ff')
    plt.plot(range(999 - multi_step, 999), preds[-multi_step:], color='#ff8000')
    plt.show()

生成的一个带些随机数的正弦波:y = torch.sin(x * 2 * 3.1415926 / 100) + 0.3 * torch.sin(x * 2 * 3.1415926 / 25) + 0.8 * np.random.normal(0, 1.5)

结果发现:MLP效果比LSTM好?!
MLP:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
偶然有不是很准,但大部分非常准

LSTM:
就很奇怪?

在这里插入图片描述
在这里插入图片描述

但是如果是纯正弦波 y = torch.sin(x23.1415926/70) ,规律太明显了,好像效果都还行:
MLP:
简单聪明的MLP第一轮就学会了
在这里插入图片描述
LSTM:
开始几轮还有些懵
在这里插入图片描述
后边就悟了
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/402989.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SQL- left join 与group by联合使用实例

表:Visits ---------------------- | Column Name | Type | ---------------------- | visit_id | int | | customer_id | int | ---------------------- visit_id 是该表中具有唯一值的列。 该表包含有关光临过购物中心的顾客的信息。表&#xff1a…

Docker容器与虚拟化技术:kylin 部署 docker容器应用

目录 一、实验 1.环境 2. kylin 部署 docker及版本升级 3.kylin 部署docker镜像加速 4.kylin 部署 nginx容器应用 5.kylin使用docker容器部署mysql实现数据持久化 6.kylin使用docker容器部署nginx实现配置文件持久化到本地 7.kylin 使⽤ docker 部署容器可视化平台porta…

【青龙】快速搭建青龙面板,部署属于你自己的应用!

青龙面板是一个支持 Python3、JavaScript、Shell、Typescript 的定时任务管理平台。 废话不多说,直接开始。 这里使用一台 雨云 的云服务器作为演示。雨云注册地址:https://www.rainyun.com/ 优惠码:lz932 使用优惠码注册后绑定微信可获得8折…

Spring框架@Autowired注解进行字段时,使用父类类型接收子类变量,可以注入成功吗?(@Autowired源码跟踪)

一、 前言 平常我们在使用spring框架开发项目过程中,会使用Autowired注解进行属性依赖注入,一般我们都是声明接口类型来接收接口实现变量,那么使用父类类型接收子类变量,可以注入成功吗?答案是肯定可以的!…

从零学习Linux操作系统第二十七部分 shell脚本中的变量

一、什么是变量 变量的定义 定义本身 变量就是内存一片区域的地址 变量存在的意义 命令无法操作一直变化的目标 用一串固定的字符来表示不固定的目标可以解决此问题 二、变量的类型及命名规范 环境级别 export A1 在环境关闭后变量失效 退出后 关闭 用户级别&#xff…

Java项目:24 基于SpringBoot+freemarker实现的人事管理系统

作者主页:源码空间codegym 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 基于SpringBootfreemarker实现的人事管理系统分为七大模块:绩效考核,招聘管理,档案管理,工资管…

Marin说PCB之如何使用mentor--xpedition-Valor软件做gerber_compare

首先打开mentro_xpedition,自带的Valor软件。 2,在File栏中选择import---odb。 3,导入生成的DOB文件。 4,在这个界面下再重新导入一份之前的参考板的ODB文件进来。 5,接着点击STEPS---board,这样单板的数据就被调进来了。 6&#…

《剑指Offer》笔记题解思路技巧优化_Part_6

《剑指Offer》笔记&题解&思路&技巧&优化_Part_6 😍😍😍 相知🙌🙌🙌 相识😢😢😢 开始刷题🟡1.LCR 168. 丑数—— 丑数🟢2. LCR 16…

2022蓝帽杯取证初赛

检材:https://pan.baidu.com/s/1ibOdxyCWeC5x0DQKjwcz7w?pwdvg6g 目录 手机取证1、627604C2-C586-48C1-AA16-FF33C3022159.PNG图片的分辨率是?(答案参考格式:19201080)2、姜总的快递单号是多少?&#xff0…

C++学习Day09之异常变量的生命周期

目录 一、程序及输出1.1 throw MyException()------catch (MyException e)1.2 throw MyException()------catch (MyException &e)1.3 throw &MyException()------catch (MyException *e)1.4 throw new MyException()------catch (MyException *e) 二、分析与总结 一、程…

QT3作业

1 2. 使用手动连接,将登录框中的取消按钮使用qt4版本的连接到自定义的槽函数中,在自定义的槽函数中调用关闭函数,将登录按钮使用t5版本的连接到自定义的槽函数中,在槽函数中判断ui界面上输入的账号是否为"admin"&#…

【C++初阶】系统实现日期类

目录 一.运算符重载实现各个接口 1.小于 (d1)<> 2.等于 (d1d2) 3.小于等于&#xff08;d1<d2&#xff09; 4.大于&#xff08;d1>d2&#xff09; 5.大于等于&#xff08;d1>d2&#xff09; 6.不等于&#xff08;d1!d2&#xff09; 7.日期天数 (1) 算…

顺序表详解(如何实现顺序表)

文章目录 前言 在进入顺序表前&#xff0c;我们先要明白&#xff0c;数据结构的基本概念。 一、数据结构的基本概念 1.1什么是数据结构 数据结构是由“数据”和“结构”两词组合而来。所谓数据就是&#xff1f;常见的数值1、2、3、4.....、姓名、性别、年龄&#xff0c;等。…

学习总结22

解题思路 简单模拟。 代码 #include <bits/stdc.h> using namespace std; long long g[2000000]; long long n; int main() {long long x,y,z,sum0,k0;scanf("%lld",&n);for(x1;x<n;x)scanf("%lld",&g[x]);for(x1;x<n;x){scanf(&qu…

尚未创建默认 SSL 站点。若要支持不带 SNI 功能的浏览器,建议创建一个默认 SSL 站点。

在 Windows Server 2012 IIS 站点中设置 SSL 证书后&#xff0c;IIS 右上角提示&#xff1a; 尚未创建默认 SSL 站点。若要支持不带 SNI 功能的浏览器&#xff0c;建议创建一个默认 SSL 站点。 该提示客户忽略不管&#xff0c;但是若要支持不带 SNI(Server Name Indication)…

外卖柜平台的设计与实现以及实践与总结

近年来&#xff0c;外卖行业的快速发展推动了外卖配送行业的进步和创新。外卖柜平台作为一种新兴的配送方式&#xff0c;在提高配送效率和服务质量方面具有很大的优势。本文将探讨美团外卖柜平台的设计与实现&#xff0c;以及如何保障其稳定性和安全性。 架构设计 美团外柜平台…

React -- useEffect

React - useEffect 概念理解 useEffect是一个React Hook函数&#xff0c;用于在React组件中创建不是由事件引起而是由渲染本身引起的操作&#xff08;副作用&#xff09;, 比 如发送AJAX请求&#xff0c;更改DOM等等 :::warning 说明&#xff1a;上面的组件中没有发生任何的用…

市场复盘总结 20240222

仅用于记录当天的市场情况&#xff0c;用于统计交易策略的适用情况&#xff0c;以便程序回测 短线核心&#xff1a;不参与任何级别的调整&#xff0c;采用龙空龙模式 一支股票 10%的时候可以操作&#xff0c; 90%的时间适合空仓等待 二进三&#xff1a; 进级率中 25% 最常用…

leetcode日记(32)字符串相乘

做了很久很久……真的太繁琐了&#xff01;&#xff01; class Solution { public:string multiply(string num1, string num2) {string s;string str;if (num1 "0" || num2 "0") return "0";for(int inum2.size()-1;i>0;i--){int c2num2[…

银行项目网上支付接口调用测试实例

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 公司最近有一个网站商城项目要开始开发了&#xff0c;这几天老板和几个同事一起开着需求会议&…