1-4.时间序列数据建模流程范例

文章最前: 我是Octopus,这个名字来源于我的中文名–章鱼;我热爱编程、热爱算法、热爱开源。所有源码在我的个人github
;这博客是记录我学习的点点滴滴,如果您对 Python、Java、AI、算法有兴趣,可以关注我的动态,一起学习,共同进步。

2020年发生的新冠肺炎疫情灾难给各国人民的生活造成了诸多方面的影响。

有的同学是收入上的,有的同学是感情上的,有的同学是心理上的,还有的同学是体重上的。

本文基于中国2020年3月之前的疫情数据,建立时间序列RNN模型,对中国的新冠肺炎疫情结束时间进行预测。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import torch 
print("torch.__version__ = ", torch.__version__)


torch.__version__ =  2.0.1

公众号 算法美食屋 回复关键词:pytorch, 获取本项目源码和所用数据集百度云盘下载链接。

import os

#mac系统上pytorch和matplotlib在jupyter中同时跑需要更改环境变量
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 


一,准备数据

本文的数据集取自tushare,获取该数据集的方法参考了以下文章。

《https://zhuanlan.zhihu.com/p/109556102》

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

df = pd.read_csv("./eat_pytorch_datasets/covid-19.csv",sep = "\t")
df.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60);

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")

dfdiff.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date",axis = 1).astype("float32")

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

dfdiff.head()
confirmed_numcured_numdead_num
0457.04.016.0
1688.011.015.0
2769.02.024.0
31771.09.026.0
41459.043.026.0

下面我们通过继承torch.utils.data.Dataset实现自定义时间序列数据集。

torch.utils.data.Dataset是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的两个方法即可:

  • __len__:实现len(dataset)返回整个数据集的大小。
  • __getitem__:用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。

不覆写这两个方法会直接返回错误。

import torch 
from torch import nn 
from torch.utils.data import Dataset,DataLoader,TensorDataset


#用某日前8天窗口数据作为输入预测该日数据
WINDOW_SIZE = 8

class Covid19Dataset(Dataset):
        
    def __len__(self):
        return len(dfdiff) - WINDOW_SIZE
    
    def __getitem__(self,i):
        x = dfdiff.loc[i:i+WINDOW_SIZE-1,:]
        feature = torch.tensor(x.values)
        y = dfdiff.loc[i+WINDOW_SIZE,:]
        label = torch.tensor(y.values)
        return (feature,label)
    
ds_train = Covid19Dataset()

#数据较小,可以将全部训练数据放入到一个batch中,提升性能
dl_train = DataLoader(ds_train,batch_size = 38)

for features,labels in dl_train:
    break 
    
#dl_train同时作为验证集
dl_val = dl_train

二,定义模型

使用Pytorch通常有三种方式构建模型:使用nn.Sequential按层顺序构建模型,继承nn.Module基类构建自定义模型,继承nn.Module基类构建模型并辅助应用模型容器进行封装。

此处选择第二种方式构建模型。

import torch
from torch import nn 
import importlib 
import torchkeras 

torch.random.seed()

class Block(nn.Module):
    def __init__(self):
        super(Block,self).__init__()
    
    def forward(self,x,x_input):
        x_out = torch.max((1+x)*x_input[:,-1,:],torch.tensor(0.0))
        return x_out
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 3层lstm
        self.lstm = nn.LSTM(input_size = 3,hidden_size = 3,num_layers = 5,batch_first = True)
        self.linear = nn.Linear(3,3)
        self.block = Block()
        
    def forward(self,x_input):
        x = self.lstm(x_input)[0][:,-1,:]
        x = self.linear(x)
        y = self.block(x,x_input)
        return y
        
net = Net()
print(net)


Net(
  (lstm): LSTM(3, 3, num_layers=5, batch_first=True)
  (linear): Linear(in_features=3, out_features=3, bias=True)
  (block): Block()
)
Net(
  (lstm): LSTM(3, 3, num_layers=5, batch_first=True)
  (linear): Linear(in_features=3, out_features=3, bias=True)
  (block): Block()
)
from torchkeras import summary
summary(net,input_data=features);
--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
==========================================================================
LSTM-1                                    [-1, 8, 3]                  480
Linear-2                                     [-1, 3]                   12
Block-3                                      [-1, 3]                    0
==========================================================================
Total params: 492
Trainable params: 492
Non-trainable params: 0
--------------------------------------------------------------------------
Input size (MB): 0.000069
Forward/backward pass size (MB): 0.000229
Params size (MB): 0.001877
Estimated Total Size (MB): 0.002174
--------------------------------------------------------------------------

三,训练模型

训练Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异。

有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环。

此处我们通过引入torchkeras库中的KerasModel工具来训练模型,无需编写自定义循环。

torchkeras详情: https://github.com/lyhue1991/torchkeras

注:循环神经网络调试较为困难,需要设置多个不同的学习率多次尝试,以取得较好的效果。

from torchmetrics.regression import MeanAbsolutePercentageError

def mspe(y_pred,y_true):
    err_percent = (y_true - y_pred)**2/(torch.max(y_true**2,torch.tensor(1e-7)))
    return torch.mean(err_percent)

net = Net() 
loss_fn = mspe
metric_dict = {"mape":MeanAbsolutePercentageError()}

optimizer = torch.optim.Adam(net.parameters(), lr=0.03)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.0001)
from torchkeras import KerasModel 
model = KerasModel(net,
       loss_fn = loss_fn,
       metrics_dict= metric_dict,
       optimizer = optimizer,
       lr_scheduler = lr_scheduler) 

dfhistory = model.fit(train_data=dl_train,
            val_data=dl_val,
            epochs=100,
            ckpt_path='checkpoint',
            patience=10,
            monitor='val_loss',
            mode='min',
            callbacks=None,
            plot=True,
            cpu=True
            )

[0;31m<<<<<< 🐌 cpu is used >>>>>>[0m

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

18.00% [18/100] [00:02<00:10]
████████████████████100.00% [1/1] [val_loss=0.4363, val_mape=0.5570]
[0;31m<<<<<< val_loss without improvement in 10 epoch,early stopping >>>>>> 
[0m

四,评估模型

评估模型一般要设置验证集或者测试集,由于此例数据较少,我们仅仅可视化损失函数在训练集上的迭代情况。

model.evaluate(dl_val)

100%|█████████████████████████████████| 1/1 [00:00<00:00, 63.91it/s, val_loss=0.384, val_mape=0.505]





{'val_loss': 0.38373321294784546, 'val_mape': 0.5048269033432007}

五,使用模型

此处我们使用模型预测疫情结束时间,即 新增确诊病例为0 的时间。

#使用dfresult记录现有数据以及此后预测的疫情数据
dfresult = dfdiff[["confirmed_num","cured_num","dead_num"]].copy()
dfresult.tail()
confirmed_numcured_numdead_num
41143.01681.030.0
4299.01678.028.0
4344.01661.027.0
4440.01535.022.0
4519.01297.017.0
#预测此后1000天的新增走势,将其结果添加到dfresult中
for i in range(1000):
    arr_input = torch.unsqueeze(torch.from_numpy(dfresult.values[-38:,:]),axis=0)
    arr_predict = model.forward(arr_input)

    dfpredict = pd.DataFrame(torch.floor(arr_predict).data.numpy(),
                columns = dfresult.columns)
    dfresult = pd.concat([dfresult,dfpredict],ignore_index=True)
dfresult.query("confirmed_num==0").head()

# 第50天开始新增确诊降为0,第45天对应3月10日,也就是5天后,即预计3月15日新增确诊降为0
# 注:该预测偏乐观

confirmed_numcured_numdead_num
500.0999.00.0
510.0948.00.0
520.0900.00.0
530.0854.00.0
540.0810.00.0

dfresult.query("cured_num==0").head()
# 第137天开始新增治愈降为0,第45天对应3月10日,也就是大概3个月后,即6月12日左右全部治愈。
# 注: 该预测偏悲观,并且存在问题,如果将每天新增治愈人数加起来,将超过累计确诊人数。
confirmed_numcured_numdead_num
1370.00.00.0
1380.00.00.0
1390.00.00.0
1400.00.00.0
1410.00.00.0

六,保存模型

模型权重保存在了model.ckpt_path路径。

print(model.ckpt_path)
checkpoint
model.load_ckpt('checkpoint') #可以加载权重

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/763432.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

已解决java.io.NotSerializableException:对象不支持序列化的正确解决方法,亲测有效!!!

已解决java.io.NotSerializableException&#xff1a;对象不支持序列化的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 出现问题的场景 示例代码 报错原因 解决思路 解决方法 1. 实现Serializable接口 修改后的Employee类 2…

递归----计算P函数

注意运算中的符号不能少&#xff01;&#xff01;&#xff01;&#xff01; * 必须体现出&#xff01;&#xff01;&#xff01;&#xff01; #include <stdio.h>double P( int n, double x );int main() {int n;double x;scanf("%d %lf", &n, &x);pri…

计算机毕业设计Python+Spark股票基金推荐与预测系统 股票基金可视化 股票基金推荐系统 股票基金可视化系统 股票基金数据分析 股票基金爬虫大数据

目 录 摘 要 Abstract 第1章 前 言 1.1 项目的背景和意义 1.2 研究现状 1.3 项目的目标和范围 1.4 论文结构简介 第2章 技术与原理 2.1 开发原理 2.2 开发工具 2.3 关键技术 第3章 需求建模 3.1 系统可行性分析 3.2 功能需求分析 3.3 非功能性…

opengl箱子的显示

VS环境配置&#xff1a; /JMC /ifcOutput "Debug\" /GS /analyze- /W3 /Zc:wchar_t /I"D:\Template\glfwtemplate\glfwtemplate\assimp" /I"D:\Template\glfwtemplate\glfwtemplate\glm" /I"D:\Template\glfwtemplate\glfwtemplate\LearnOp…

Wireshark - tshark支持iptables提供数据包

tshark现在的数据包获取方式有两种&#xff0c;分别是读文件、网口监听&#xff08;af-packet原始套接字&#xff09;。两种方式在包获取上&#xff0c;都是通过读文件的形式&#xff1b;存在文件io操作&#xff0c;在专门处理大流量的情境下&#xff0c; 我们复用wireshark去做…

小阿轩yx-案例:MySQL主从复制与读写分离

小阿轩yx-案例&#xff1a;MySQL主从复制与读写分离 案例分析 概述 实际生产环境中 如果对数据库读和写都在同一个数据库服务器中操作&#xff0c;无论在安全性、高可用性还是高并发等各个方面都完全不能满足实际需求一般都是通过主从复制&#xff08;Master-Slave&#xf…

Python tkinter: 开发一个目标检测GUI小程序

程序提供了一个用户友好的界面&#xff0c;允许用户选择图片或文件夹&#xff0c;使用行人检测模型进行处理&#xff0c;并在GUI中显示检测结果。用户可以通过点击画布上的检测结果来获取更多信息&#xff0c;并使用键盘快捷键来浏览不同的图片。 一. 基本功能介绍 界面布局&am…

C++封装

1. 封装 1.1. struct 当单一变量无法完成描述需求的时候&#xff0c;结构体类型解决了这一问题。可以将多个类型打包成一体&#xff0c;形成新的类型&#xff0c;这是c语言中的封装 但是&#xff0c;新类型并不包含&#xff0c;对数据类的操作。所有操作都是通过函数的方式进…

CrimsonEDR:一款恶意软件模式识别与EDR策略评估工具

关于CrimsonEDR CrimsonEDR是一个功能强大的开源项目&#xff0c;该项目旨在帮助广大研究人员识别特定的恶意软件模式&#xff0c;以此来优化终端检测与响应&#xff08;EDR&#xff09;的策略方案。通过使用各种不同的检测方案&#xff0c;可以加深开发人员与研究人员加深对安…

在Ubuntu 14.04上安装和配置Mumble服务器(Murmur)的方法

前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 介绍 Mumble是一款免费开源的语音通信应用程序&#xff0c;主要设计用于游戏玩家使用。Mumble类似于TeamSpeak和Ventrilo。Mumble采用客…

考研生活day1--王道课后习题2.2.1、2.2.2、2.2.3

2.2.1 题目描述&#xff1a; 解题思路&#xff1a; 这是最基础的操作&#xff0c;思路大家应该都有&#xff0c;缺少的应该是如何下笔&#xff0c;很多同学都是有思路但是不知道如何下笔&#xff0c;这时候看思路的意义不大&#xff0c;可以直接看答案怎么写&#xff0c;最好…

cube-studio 开源一站式云原生机器学习/深度学习/大模型训练推理平台介绍

全栈工程师开发手册 &#xff08;作者&#xff1a;栾鹏&#xff09; 一站式云原生机器学习平台 前言 开源地址&#xff1a;https://github.com/tencentmusic/cube-studio cube studio 腾讯开源的国内最热门的一站式机器学习mlops/大模型训练平台&#xff0c;支持多租户&…

python sklearn机械学习模型-分类

&#x1f308;所属专栏&#xff1a;【机械学习】✨作者主页&#xff1a; Mr.Zwq✔️个人简介&#xff1a;一个正在努力学技术的Python领域创作者&#xff0c;擅长爬虫&#xff0c;逆向&#xff0c;全栈方向&#xff0c;专注基础和实战分享&#xff0c;欢迎咨询&#xff01; 您…

什么是应用安全态势管理 (ASPM):综合指南

软件开发在不断发展&#xff0c;应用程序安全也必须随之发展。 传统的应用程序安全解决方案无法跟上当今开发人员的工作方式或攻击者的工作方式。 我们需要一种新的应用程序安全方法&#xff0c;而ASPM在该方法中发挥着关键作用。 什么是 ASPM&#xff1f; 应用程序安全…

神经网络训练(一):基于残差连接的图片分类网络(ResNet18)

目录 一、简介:二、图片分类网络1.记载训练数据(torch自带的cifa10数据集)2.数据增强3.模型构建4.模型训练三、完整源码及文档一、简介: 基于残差连接的图片分类网络,本网络使用ResNet18作为基础模块,根据cifa10的特点进行改进网络,使用交叉熵损失函数和SGD优化器。本网…

源代码层面分析Appium-inspector工作原理

Appium-inspector功能 Appium Inspector 基于 Appium 框架&#xff0c;Appium 是一个开源工具&#xff0c;用于自动化移动应用&#xff08;iOS 和 Android&#xff09;和桌面应用&#xff08;Windows 和 Mac&#xff09;。Appium 采用了客户端-服务器架构&#xff0c;允许用户通…

实践Go的命令模式

简介 现在的软件系统往往是分层设计。在业务层执行一次请求时&#xff0c;我们很清楚请求的上下文&#xff0c;包括&#xff0c;请求是做什么的、参数有哪些、请求的接收者是谁、返回值是怎样的。相反&#xff0c;基础设施层并不需要完全清楚业务上下文&#xff0c;它只需知道…

Typora导出为Word

文章目录 一、场景二、安装1、网址2、解压并验证 三、配置四、重启Typora 一、场景 在使用Typora软件编辑文档时&#xff0c;我们可能需要将其导出为Word格式文件 当然我们可以直接在菜单里进行导出操作 文件-> 导出-> Word(.docx) 如果是第一次导出word文件&#xff0…

Python实现接糖果小游戏

介绍: 基于Pygame的糖果从屏幕顶部下落的游戏代码。这个游戏包括了一个可以左右移动的篮子来接住下落的糖果&#xff0c;接住糖果会增加得分。 代码: import pygame import random import os# 初始化pygame和设置屏幕大小 pygame.init() screen_width, screen_height 800, 6…

数据资产的创新应用与未来展望:探讨数据资产在人工智能、物联网等新兴领域的应用前景,提出前瞻性的数据资产解决方案,为企业探索新的增长点,推动行业创新发展

目录 一、引言 二、数据资产在人工智能领域的应用 1、机器学习与深度学习 2、自然语言处理 3、计算机视觉 三、数据资产在物联网领域的应用 1、智能家居 2、工业物联网 3、智慧城市 四、前瞻性的数据资产解决方案 1、构建统一的数据管理平台 2、加强数据安全和隐私…