深度学习(13)--PyTorch搭建神经网络进行气温预测

一.搭建神经网络进行气温预测流程详解

1.1.导入所需的工具包

import numpy as np  # 矩阵计算
import pandas as pd   # 数据读取
import matplotlib.pyplot as plt  # 画图处理
import torch  # 构建神经网络
import torch.optim as optim  # 设置优化器

1.2.读取并处理数据

引入数据并查看数据的格式

# 引入数据
features = pd.read_csv('temps.csv')

# 看看数据长什么样子
print(features.head())

Pandas库中的.head()函数,取数据的前n行数据,默认是取前五行数据,如上图所示。

查看数据维度

print('数据维度:', features.shape)

shape函数的功能是读取矩阵的长度,.shape直接输出数据的维度,如上图,表示该数据的维度为348行,9列。对应的也就是348个样本,9个特征。

而shape[0],shape[1]则分别返回矩阵第一维度、第二维度的长度:

# 查看数据维度
print('数据维度:', features.shape[0])
print('数据维度:', features.shape[1])

处理时间数据

# 处理时间数据
import datetime

# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']

# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

 查看处理的datas数据格式

print(dates[:3])

对特殊数据进行one-hot编码 

计算机无法识别字符串数据,所以对于字符串数据需要使用one-hot编码:

features = pd.get_dummies(features) 

# get_dummies会自动判断数据中哪一列是字符串,并自动将字符串展开。

# eg:数据中用于标注星期的字符串一共有七个,则get_dummies函数将数据展开成七列,当天是哪一天就在相应位置标1。

# 星期 一 二 三 四 五 六 七,如果是星期一则标注为:1 0 0 0 0 0 0,如果是星期三则标注为:0 0 1 0 0 0 0,如果是星期六则标注为:0 0 0 0 0 1 0

 查看one-hot编码后的数据

对标签进行处理

# 标签
labels = np.array(features['actual'])  # 获取标签:features获取actual的标签然后再转换为np.array的格式

# 在特征中去掉标签
features= features.drop('actual', axis = 1)  # 去除features中的actual标签,axis表示沿着行/列去除,axis=0按行计算,axis=1按列计算

# 名字单独保存一下,以备后患
feature_list = list(features.columns)  # 保存features中的columns值,也就是列

# 转换成合适的格式
features = np.array(features)  # 把处理后的features数据也转换为np.array格式

标准化处理 

不同的数据取值范围不同,而机器又会认为数值大的数据较为重要,所以需要对数据进行标准化(x-μ/σ) -- μ为均值,σ为标准差。

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)  # fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。

fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。

标准化处理前后的数据:

1.3.构建网络模型

构建网络

本项目构建的网络模型较为简单,只有一个隐层

# shape[0]是样本数,也就是行的数据,shape[1]是特征数,也就是列的数据
input_size = input_features.shape[1]  
hidden_size = 128
output_size = 1
batch_size = 16  # 一次迭代batch个样本
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),  # 根据输入自动初始化权重参数和偏重值
    torch.nn.ReLU(),  # 激活函数 Sigmoid/Relu
    torch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')  # 损失函数设置:MSE均方误差
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)  
# 优化器设置:Adam,参数为网络中的所有参数以及学习率

训练网络

# 训练网络
losses = []
# 迭代1000次,epoch = 1000
for i in range(1000):
    batch_loss = []
    # MINI-Batch方法来进行训练
    for start in range(0, len(input_features), batch_size):  # 循环范围为0~样本数,每次循环中间间隔batchs_size
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 做一个索引是否越界的判断
        # 取得一个batch的数据
        xx = torch.tensor(input_features[start:end], dtype = torch.float, requires_grad = True)
        yy = torch.tensor(labels[start:end], dtype = torch.float, requires_grad = True)
        prediction = my_nn(xx)  # 输入值经过定义的网络运算得到预测值
        loss = cost(prediction, yy)  # 参数为预测值和真实值
        optimizer.zero_grad()  # torch的迭代过程中会累计之前的训练结果,所以在每次迭代中需要清空梯度值
        loss.backward(retain_graph=True)  # 反向传播
        optimizer.step()  # 对所有参数进行更新
        batch_loss.append(loss.data.numpy())
    
    # 打印损失
    if i % 100==0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))

预测训练结果 

x = torch.tensor(input_features, dtype = torch.float)  
# 先将数据转换为tensor格式,因为需要在网络中进行运算
predict = my_nn(x).data.numpy()  
# 在网络中运算完成中,再转换为data.numpy格式,因为后续需要进行画图处理

1.4.对结果进行画图对比

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})   # predict是x经过网络训练再转换为np.array的值

# 画图

# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')  # 参数分别为:横轴,纵轴,曲线颜色,标签值

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')  # 参数分别为:横轴,纵轴,曲线颜色,标签值
plt.xticks(rotation=30)  # x轴参数倾斜60°
plt.legend()  # 使上述代码产生效果

# 图名
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')  # x,y轴标签设置
plt.title('Actual and Predicted Values')  # 图名设置

# 保存图片并展示
plt.savefig("result.png")
plt.show()

二.完整代码

import numpy as np  # 矩阵计算
import pandas as pd   # 数据读取
import matplotlib.pyplot as plt  # 画图处理
import torch  # 构建神经网络
import torch.optim as optim  # 设置优化器

# 处理时间数据
import datetime

from sklearn import preprocessing

# 引入数据
features = pd.read_csv('temps.csv')

# 看看数据长什么样子
# print(features.head())

# 查看数据维度
# print('数据维度:', features.shape)


# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']

'''
# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

print(dates[:3])
'''

# 独热(one-hot)编码 -- 机器不认识字符串,需要将字符串转换为机器认识的参数
features = pd.get_dummies(features)
# get_dummies会自动判断数据中哪一列是字符串,并自动将字符串展开
# eg:数据中用于标注星期的字符串一共有七个,则get_dummies函数将数据展开成七列,当天是哪一天就在相应位置标1
# 星期 一 二 三 四 五 六 七,如果是星期一则标注为:1 0 0 0 0 0 0,如果是星期三则标注为:0 0 1 0 0 0 0,如果是星期六则标注为:0 0 0 0 0 1 0
# print(features.head(5))

# 标签
labels = np.array(features['actual'])  # 获取标签:features获取actual的标签然后再转换为np.array的格式

# 在特征中去掉标签
features = features.drop('actual', axis = 1)  # 去除features中的actual标签,axis表示沿着行/列去除,axis=0按行计算,axis=1按列计算

# 名字单独保存一下,以备后患
feature_list = list(features.columns)  # 保存features中的columns值,也就是列

# 转换成合适的格式
features = np.array(features)  # 把处理后的features数据也转换为np.array格式

# print(features[0])
# 标准化处理
input_features = preprocessing.StandardScaler().fit_transform(features)
# fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。
# print(input_features[0])

# shape[0]是样本数,也就是行的数据,shape[1]是特征数,也就是列的数据
input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16  # 一次迭代batch个样本
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),  # 根据输入自动初始化权重参数和偏重值
    torch.nn.ReLU(),  # 激活函数 Sigmoid/ReLU
    torch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')  # 损失函数设置:MSE均方误差
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)
# 优化器设置:Adam,参数为网络中的所有参数以及学习率

# 训练网络
losses = []
# 迭代1000次,epoch = 1000
for i in range(1000):
    batch_loss = []
    # MINI-Batch方法来进行训练
    for start in range(0, len(input_features), batch_size):  # 循环范围为0~样本数,每次循环中间间隔batchs_size
        end = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 做一个索引是否越界的判断
        # 取得一个batch的数据
        xx = torch.tensor(input_features[start:end], dtype=torch.float, requires_grad=True)
        yy = torch.tensor(labels[start:end], dtype=torch.float, requires_grad=True)
        prediction = my_nn(xx)  # 输入值经过定义的网络运算得到预测值
        loss = cost(prediction, yy)  # 参数为预测值和真实值
        optimizer.zero_grad()  # torch的迭代过程中会累计之前的训练结果,所以在每次迭代中需要清空梯度值
        loss.backward(retain_graph=True)  # 反向传播
        optimizer.step()  # 对所有参数进行更新
        batch_loss.append(loss.data.numpy())

    '''
    # 打印损失
    if i % 100 == 0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))
    '''

# 预测训练结果
x = torch.tensor(input_features, dtype=torch.float)
# 先将数据转换为tensor格式,因为需要在网络中进行运算
predict = my_nn(x).data.numpy()
# 在网络中运算完成中,再转换为data.numpy格式,因为后续需要进行画图处理

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})   # predict是x经过网络训练再转换为np.array的值

# 画图

# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')  # 参数分别为:横轴,纵轴,曲线颜色,标签值

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')  # 参数分别为:横轴,纵轴,曲线颜色,标签值
plt.xticks(rotation=30)  # x轴参数倾斜60°
plt.legend()  # 使上述代码产生效果

# 图名
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')  # x,y轴标签设置
plt.title('Actual and Predicted Values')  # 图名设置

# 保存图片并展示
plt.savefig("result.png")
plt.show()


三.输出结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/380435.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux 36.2@Jetson Orin Nano基础环境构建

Linux 36.2Jetson Orin Nano基础环境构建 1. 源由2. 步骤2.1 安装NVIDIA Jetson Linux 36.2系统2.2 必备软件安装2.3 基本远程环境2.3.1 远程ssh登录2.3.2 samba局域网2.3.3 VNC远程登录 2.4 开发环境安装 3. 总结 1. 源由 现在流行什么&#xff0c;也跟风来么一个一篇。当然&…

【数据结构与算法-初学者指南】【附带力扣原题】队列

&#x1f389;&#x1f389;欢迎光临&#x1f389;&#x1f389; &#x1f3c5;我是苏泽&#xff0c;一位对技术充满热情的探索者和分享者。&#x1f680;&#x1f680; &#x1f31f;特别推荐给大家我的最新专栏《数据结构与算法&#xff1a;初学者入门指南》&#x1f4d8;&am…

保育员答案在哪搜?这4款足够解决问题 #媒体#其他#其他

学会运用各类学习辅助工具和资料&#xff0c;是大学生培养自主学习能力和信息获取能力的重要途径之一。 1.石墨文档 石墨文档(Shimo Docs)是一款强大的在线文档协作工具。它提供了多人实时协作、版本控制、评论和批注等功能&#xff0c;方便学生在学习中进行文档编写、合作项…

完美运营的最新视频打赏系统及附带教程

(购买本专栏可免费下载栏目内所有资源不受限制,持续发布中,需要注意的是,本专栏为批量下载专用,并无法保证某款源码或者插件绝对可用,介意不要购买) 优于市面上95%的打赏系统,与其他打赏系统相比,功能更加强大,完美运营且无bug。支付会调、短链接生成、代理后台、价…

企业内部知识库管理软件的终极指南:如何选择最适合你的工具?

知识库管理软件对于希望提高客户支持和组织效率的公司来说是一个强大的工具。在数字时代&#xff0c;拥有一个可靠的知识库系统对于快速准确地满足客户需求至关重要。在当今的技术条件下&#xff0c;知识库管理软件有很多选择&#xff0c;每个企业都应该仔细评估并选择最适合自…

阿里云幻兽帕鲁服务器有用过的吗?搭建简单啊

玩转幻兽帕鲁服务器&#xff0c;幻兽帕鲁Palworld多人游戏专用服务器一键部署教程&#xff0c;阿里云推出新手0基础一键部署幻兽帕鲁服务器教程&#xff0c;傻瓜式一键部署&#xff0c;3分钟即可成功创建一台Palworld专属服务器&#xff0c;成本仅需26元&#xff0c;阿里云百科…

CTFSHOW命令执行web入门29-54

description: >- 这里就记录一下ctfshow的刷题记录是web入门的命令执行专题里面的题目,他是有分类,并且覆盖也很广泛,所以就通过刷这个来,不过里面有一些脚本的题目发现我自己根本不会笑死。 如果还不怎么知道写题的话,可以去看我的gitbook,当然csdn我也转载了我自己的…

python智慧养老系统—养老信息服务平台vue

本论文中实现的智慧养老系统-养老信息服务平台将以管理员和用户的日常信息维护工作为主&#xff0c;主要涵盖了系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;养老资讯管理&#xff0c;养生有道管理&#xff0c;养老机构管理&#xff0c;系统管理等功能&#x…

【Langchain Agent研究】SalesGPT项目介绍(一)

【2024最全最细LangChain教程-13】Agent智能体&#xff08;二&#xff09;-CSDN博客 之前我们介绍了langchain的agent&#xff0c;其实不难看出&#xff0c;agent是更高级的chain&#xff0c;可以进行决策分析、可以使用工具&#xff0c;今天我们开始开启一些更高阶的课程&…

如何连接ChatGPT?无需科学上网,使用官方GPT教程

随着AI的发展&#xff0c;ChatGPT也越来越强大了。 它可以帮你做你能想到的几乎任何事情&#xff0c;妥妥的生产力工具。 然而&#xff0c;对于许多国内的用户来说&#xff0c;并不能直接使用ChatGPT&#xff0c;不过没关系&#xff0c;我最近发现了一个可以直接免科学上网连…

【大厂AI课学习笔记】【1.5 AI技术领域】(8)文本分类

8,9,10&#xff0c;将分别讨论自然语言处理领域的3个重要场景。 自然语言处理&#xff0c;Natual Language Processing&#xff0c;NLP&#xff0c;包括自然语言识别和自然语言生成。 用途是从非结构化的文本数据中&#xff0c;发掘洞见&#xff0c;并访问这些信息&#xff0…

电脑文件误删除怎么办?8个恢复软件解决电脑磁盘数据可能的误删

您是否刚刚发现您的电脑磁盘数据丢失了&#xff1f;不要绝望&#xff01;无论分区是否损坏、意外格式化或配置错误&#xff0c;存储在其上的文件都不一定会丢失到数字深渊。 我们已经卷起袖子&#xff0c;深入研究电脑分区恢复软件的广阔领域&#xff0c;为您带来一系列最有效…

Qt安装配置教程windows版(包括:Qt5.8.0版本,Qt5.12,Qt5.14版本下载安装教程)(亲测可行)

目录 Qt5.8.0版本安装教程Qt5.8.0版本下载安装 Qt5.12.2版本安装教程下载安装 Qt 5.14.2安装教程下载安装和创建项目 参考视频 QT为嵌入式系统提供了大量的库和可重用组件。 WPS Office&#xff0c;咪咕音乐&#xff0c;Linux桌面环境等都是QT开发的。 Qt5.8.0版本安装教程 Q…

排序算法---堆排序

原创不易&#xff0c;转载请注明出处。欢迎点赞收藏~ 堆排序&#xff08;Heap Sort&#xff09;是一种基于二叉堆数据结构的排序算法。它将待排序的元素构建成一个最大堆&#xff08;或最小堆&#xff09;&#xff0c;然后逐步将堆顶元素与堆的最后一个元素交换位置&#xff0c…

javaEE - 22( 5000 字 Tomcat 和 HTTP 协议入门 -3)

一&#xff1a;Tomcat 1.1 Tomcat 是什么 谈到 “汤姆猫”, 大家可能更多想到的是大名鼎鼎的这个: 事实上, Java 世界中的 “汤姆猫” 完全不是一回事, 但是同样大名鼎鼎. Tomcat 是一个 HTTP 服务器. 前面我们已经学习了 HTTP 协议, 知道了 HTTP 协议就是 HTTP 客户端和…

Java编程构建高效二手交易平台

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

物资捐赠管理系统

文章目录 物资捐赠管理系统一、项目演示二、项目介绍三、系统部分功能截图四、部分代码展示五、底部获取项目&#xff08;9.9&#xffe5;带走&#xff09; 物资捐赠管理系统 一、项目演示 爱心捐赠系统 二、项目介绍 基于springboot的爱心捐赠管理系统 开发语言&#xff1a…

APEX开发过程中需要注意的小细节2

开发时遇到首次获取租户号失败的问题 以为是触发顺序问题&#xff0c;所以设置两个动态操作&#xff0c;一个事件是“更改”&#xff0c;另一个是“单击”&#xff0c; 但还是没有解决&#xff0c; 后来终于找到解决方法:在校验前执行取值 果然成功执行&#xff01; 动态查询年…

获取视频帧图片

在实现了minio文件上传的基础上进行操作 一、编写pom <dependency><groupId>org.jcodec</groupId><artifactId>jcodec</artifactId><version>0.2.5</version> </dependency> <dependency><groupId>org.jcodec<…

30岁还一事无成,怎么办?

前些日子&#xff0c;知乎有一个话题&#xff0c;特别火。 原话是&#xff1a;30岁&#xff0c;如果你还没当上管理层&#xff0c;或者在某个领域取得成就&#xff0c;那你一辈子基本也就这样了。 这句话一出&#xff0c;戳中了许多人的软肋&#xff0c;一时间群情哗然。 理由是…