基于LSTM及其变体的回归预测

1 所用模型

       代码中用到了以下模型:

      1. LSTM(Long Short-Term Memory):长短时记忆网络,是一种特殊的RNN(循环神经网络),能够解决传统RNN在处理长序列时出现的梯度消失或爆炸的问题。LSTM有门控机制,可以选择性地记住或忘记信息。

       2. FC-LSTM:全连接的LSTM,与传统的LSTM相比,其细胞单元之间采用全连接的方式。

       3. Coupled LSTM:耦合LSTM,是一种特殊的LSTM结构,其中每个LSTM单元被分解为两个交互的子单元。

       4. GRU(Gated Recurrent Unit):门控循环单元,与LSTM类似,但结构更简单,参数更少,通常训练更快,但可能不如LSTM准确。

       5. ConvLSTM:卷积LSTM,将卷积神经网络(CNN)与LSTM结合,可以捕捉时空特征,常用于处理图像和视频数据。

       6. Deep LSTM:深层LSTM,包含多个LSTM层的堆叠,可以捕捉更复杂的模式。

       7. DB-LSTM(Bidirectional LSTM):双向LSTM,有两个方向的LSTM层,一个按时间顺序,一个逆序,可以同时获取过去和未来的信息。

       8. SRU(SimpleRNN):简单循环神经网络,是最基本的RNN形式。

       9. TPA-LSTM:时间感知LSTM,通过改变LSTM的内部计算方式,使其更加关注时间序列的特性。

       10. ConvGRU:卷积GRU,与ConvLSTM类似,但使用GRU代替LSTM。

       这些模型都是用于处理序列数据的深度学习模型,特别适用于时间序列预测、自然语言处理等领域。

2 运行结果

       左边是Epoch=50次的效果,右边是Epoch=15次的效果:

a1e88c48c6f645eea96360f59b239c00.jpg

 图2-1 训练损失

3623cb88b9294ce796d7dbacd244f481.jpg

 图2-2 测试损失

d9ab03d1196542bf9235bafc58288e07.jpg

 图2-3 预测结果

3 代码

     

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LSTM, GRU, SimpleRNN, Bidirectional, TimeDistributed, Conv1D, Attention
from keras.layers import Flatten, Dropout, BatchNormalization
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
from tensorflow.keras.layers import Conv1D
# 读取数据
data = pd.read_excel('A.xlsx')
data=data.dropna()
data = data['A'].values.reshape(-1, 1)
# 数据预处理
scaler = MinMaxScaler()
data = scaler.fit_transform(data)

# 划分训练集和测试集
train_size = int(len(data) * 0.8)
train, test = data[:train_size], data[train_size:]

# 转换数据格式以适应LSTM输入
def create_dataset(dataset, look_back=1):
    X, Y = [], []
    for i in range(len(dataset) - look_back - 1):
        X.append(dataset[i:(i + look_back), 0])
        Y.append(dataset[i + look_back, 0])
    return np.array(X), np.array(Y)
 
look_back = 1
X_train, y_train = create_dataset(train, look_back)
X_test, y_test = create_dataset(test, look_back)
 
# 重塑输入数据的维度以适应LSTM模型
X_train = np.reshape(X_train, (X_train.shape[0], 1, X_train.shape[1]))
X_test = np.reshape(X_test, (X_test.shape[0], 1, X_test.shape[1]))
# 定义模型函数
def create_model(name):
    model = Sequential()
    if name == 'LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1)))
    elif name == 'FC-LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), recurrent_activation='sigmoid'))
    elif name == 'Coupled LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), implementation=2))
    elif name == 'GRU':
        model.add(GRU(50, activation='relu', input_shape=(1, 1)))
    elif name == 'ConvLSTM':
        model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))
        model.add(LSTM(50, activation='relu'))
    elif name == 'Deep LSTM':
        model.add(LSTM(50, return_sequences=True, activation='relu', input_shape=(1, 1)))
        model.add(LSTM(50, activation='relu'))
    elif name == 'DB-LSTM':
        model.add(Bidirectional(LSTM(50, activation='relu'), input_shape=(1, 1)))
    elif name == 'SRU':
        model.add(SimpleRNN(50, activation='relu', input_shape=(1, 1)))
    elif name == 'TPA-LSTM':
        model.add(LSTM(50, activation='relu', input_shape=(1, 1), unroll=True))
    elif name == 'ConvGRU':
        model.add(Conv1D(filters=64, kernel_size=1, activation='relu', input_shape=(1, 1)))
        model.add(GRU(50, activation='relu'))
    model.add(Dense(1))
    model.compile(optimizer=Adam(), loss='mse')
    return model

# 训练模型并绘制损失图
names = ['LSTM', 'FC-LSTM', 'Coupled LSTM', 'GRU', 'ConvLSTM', 'Deep LSTM', 'DB-LSTM','SRU', 'TPA-LSTM', 'ConvGRU']
train_losses = []
test_losses = []
predictions = []

for name in names:
    model = create_model(name)
    history = model.fit(train, train, epochs=15, batch_size=32, validation_data=(test, test), verbose=0)
    train_losses.append(history.history['loss'])
    test_losses.append(history.history['val_loss'])
    pred = model.predict(test)
    predictions.append(pred)
    
    
import matplotlib.pyplot as plt

# 设置不同的marker
markers = ['o', '.', '_', '^', '*', '>', '+', '1', 'p', '_', '8']
linestyles = ['-', '--', '--', ':', '-', '-.', '-.', ':', '-', '--']
# 绘制训练损失图
plt.figure(figsize=(16, 20))
for i, loss in enumerate(train_losses):
    plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制测试损失图
for i, loss in enumerate(test_losses):
    plt.plot(loss, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
plt.title('Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(fontsize=8,loc='best')
plt.show()
# 绘制预测结果折线图
for i, pred in enumerate(predictions):
    plt.plot(pred, color='black',label=names[i], marker=markers[i], linestyle=linestyles[i])
# 绘制真实值折线图
plt.plot(y_test, color='black', label='True Value')
plt.title('Predictions and True Values')
plt.xlabel('x')
plt.ylabel('value')
plt.legend(fontsize=8, loc='best')
# 显示图像
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/802386.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MBR40150FCT-ASEMI无人机专用MBR40150FCT

编辑:ll MBR40150FCT-ASEMI无人机专用MBR40150FCT 型号:MBR40150FCT 品牌:ASEMI 封装:TO-220F 批号:最新 最大平均正向电流(IF):40A 最大循环峰值反向电压(VRRM&a…

typeorm实体多对多关系指定表名与关联字段

表结构 user 用户表结构 course 课程表结构 user_course 用户课程表 (每个用户可以有多个课程, 每个课程可以有多个用户, 该表用以建立多对多关系) 实体 user.entity.ts Entity(user, { schema: test }) export class User {PrimaryGeneratedColumn({ type: int, name: id }…

江科大SPI教程听课笔记

原理部分我打算听江科大的课复习一下,代码部分工作大概率用HAL库敲了。 SPI(Serial Peripheral Interface)是由Motorola公司开发的一种通用数据总线。 硬件资源方面需要四根通信线:SCK(Serial Clock)、MOSI(Master Output Slave Input)、MISO (Master Input Slave…

自定义组件--密码修改对话框(拿来即用型)

前言 一个完整的系统中用户登录功能是不可或缺的,因此用户密码的修改对于前端开发者而言也是工作的重要一环,密码修改分为两种情况:一是用户自身想更换密码;另一种是忘记密码只能选择更换密码。本文自定义了一个通用且常见的组件-…

IDEA快速生成项目树形结构图

下图用的IDEA工具,但我觉得WebStorm 应该也可以 文章目录 进入项目根目录下,进入cmd输入如下指令: 只有文件夹 tree . > list.txt 包括文件夹和文件 tree /f . > list.txt 还可以为相关包路径加上注释

【STM32嵌入式系统设计与开发---拓展】——1_9_1上拉输入和下拉输入

在使用GPIO引脚时,上拉输入和下拉输入的选择取决于外部电路的特性和应用需求。以下是它们各自的应用场景: 1、上拉输入(Pull-up Input) 用途: 当默认状态需要为高电平时。 避免引脚悬空(floating)导致的…

Three.JS 使用RGBELoader和CubeTextureLoader 添加环境贴图

导入RGBELoader模块: import { RGBELoader } from "three/examples/jsm/loaders/RGBELoader.js"; 使用 addRGBEMappingk(environment, background,url) {rgbeLoader new RGBELoader();rgbeLoader.loadAsync(url).then((texture) > {//贴图模式 经纬…

MongoDB教程(八):mongoDB数据备份与恢复

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 文章目录 引言MongoDB 备…

socket功能定义和一般模型

1. socket的功能定义 socket是为了使两个应用程序间进行数据交换而存在的一种技术,不仅可以使同一个主机上两个应用程序间可以交换数据,而且可以使网络上的不同主机间上的应用程序间进行通信。 2. 图解socket的服务端/客户端模型

深度学习落地实战:基于UNet实现血管瘤超声图像分割

前言 大家好,我是机长 本专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代…

cpp 强制转换

一、static_cast static_cast 是 C 中的一个类型转换操作符,用于在类的层次结构中进行安全的向上转换(从派生类到基类)或进行不需要运行时类型检查的转换。它主要用于基本数据类型之间的转换、对象指针或引用的向上转换(即从派生…

【Redis】集群

文章目录 一、集群是什么?二、 Redis集群分布式存储为什么redis集群的最大槽数是16384(不太懂)redis的集群主节点数量基本不可能超过1000个 三、 配置集群(三主三从)3.1 配置config文件3.2 启动六台redis3.2 通过redis…

铜管和铝管、铝管和铝管焊接操作介绍

一、部分品牌冰箱、空调采用铜铝管或铝铝管之间的连接方式,连接方式有以下两种: 1、洛克环:是方便简单的方式,但其需从德国采购,成本过于高昂而且采购周期长; 2、铜铝异种材料钎焊技术:国内可…

怎样在 PostgreSQL 中优化对大表的索引创建和维护的性能开销?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会!📚领书:PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中优化对大表的索引创建和维护的性能开销?一、理解大表和索引的概念&am…

[C++]——同步异步日志系统(7)

同步异步日志系统 一、日志器管理模块(单例模式)1.1 对日志器管理器进行设计1.2 实现日志器管理类的各个功能1.3. 设计一个全局的日志器建造者1.4 测试日志器管理器的接口和全局建造者类 二、宏函数和全局接口设计2.1 新建一个.h,文件,文件里面放我们写的…

小欧吃苹果-OPPO 2024届校招正式批笔试题-数据开发(C卷)

在处理这个问题前&#xff0c;先看一个经典的贪心算法题目。信息学奥赛一本通&#xff08;C版&#xff09;在线评测系统http://ybt.ssoier.cn:8088/problem_show.php?pid1320 注意移动纸牌的贪心策略并不是题目中给出的移动次序&#xff1a;第1堆纸牌9<10&#xff0c;因为是…

几何相关计算

目录 一、 判断两个矩形是否相交 二、判断两条线段是否相交 三、判断点是否在多边形内 四、垂足计算 五、贝塞尔曲线 六、坐标系 一、 判断两个矩形是否相交 当矩形1的最大值比矩形2的最小值都小&#xff0c;那矩形1和矩形2一定不相交&#xff0c;其他同理。 struct Po…

【STM32】按键控制LED光敏传感器控制蜂鸣器(江科大)

一、按键控制LED LED.c #include "stm32f10x.h" // Device header/*** 函 数&#xff1a;LED初始化* 参 数&#xff1a;无* 返 回 值&#xff1a;无*/ void LED_Init(void) {/*开启时钟*/RCC_APB2PeriphClockCmd(RCC_APB2Periph_GPIOA, ENAB…

醇香之旅:探索红酒的无穷魅力

在浩渺的饮品世界里&#xff0c;红酒如同一颗璀璨的星辰&#xff0c;闪烁着诱人的光芒。它以其不同的醇香和深邃的韵味&#xff0c;吸引着无数人的目光。今天&#xff0c;就让我们一起踏上这场醇香之旅&#xff0c;探索雷盛红酒所带来的无穷魅力。 一、初识红酒的醇香 当我们…

去除重复字母

题目链接 去除重复字母 题目描述 注意点 s 由小写英文字母组成1 < s.length < 10^4需保证 返回结果的字典序最小&#xff08;要求不能打乱其他字符的相对位置&#xff09; 解答思路 本题与移掉 K 位数字类似&#xff0c;需要注意的是&#xff0c;并不是每个字母都能…