TensorFlow2实战-系列教程1:回归问题预测

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、环境测试

import tensorflow as tf
import numpy as np
tf.__version__

打印结果

‘2.10.0’

x1 =[[1,9],[3,6]]
x2 = tf.constant(x1)
print(x1)
print(x2)

打印结果:

[[1, 9], [3, 6]]
tf.Tensor( [[1 9] [3 6]], shape=(2, 2), dtype=int32)

2、导包读数据

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.keras
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

features = pd.read_csv('temps.csv')
features.head()

打印结果:

在这里插入图片描述

  • year,moth,day,week分别表示的具体的时间
  • temp_2:前天的最高温度值
  • temp_1:昨天的最高温度值
  • average:在历史中,每年这一天的平均最高温度值
  • actual:这就是我们的标签值了,当天的真实最高温度
  • friend:这一列可能是凑热闹的,你的朋友猜测的可能值,咱们不管它就好了

星期是个文本特性,用onehot转换一下:

features = pd.get_dummies(features)
features.head(5)

在这里插入图片描述

3、标签制作与数据预处理

# 标签
labels = np.array(features['actual'])

# 在特征中去掉标签
features= features.drop('actual', axis = 1)

# 名字单独保存一下,以备后患
feature_list = list(features.columns)

# 转换成合适的格式
features = np.array(features)

打印结果:

(348, 14)

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)
input_features[0]

打印结果:

array([ 0. , -1.5678393 , -1.65682171, -1.48452388, -1.49443549, -1.3470703 , -1.98891668, 2.44131112, -0.40482045, -0.40961596, -0.40482045, -0.40482045, -0.41913682, -0.40482045])

4、 基于Keras构建网络模型

常用参数:

  • activation:激活函数的选择,一般常用relu
  • kernel_initializer,bias_initializer:权重与偏置参数的初始化方法
  • kernel_regularizer,bias_regularizer:要不要加入正则化
  • inputs:输入,可以自己指定,也可以让网络自动选 units:神经元个数

按顺序构造网络模型:

model = tf.keras.Sequential()
model.add(layers.Dense(16))
model.add(layers.Dense(32))
model.add(layers.Dense(1))
  1. 创建一个执行序列
  2. 添加全连接层,16个神经元
  3. 添加全连接层,32个神经元
  4. 添加全连接层,1个神经元,作为最后的输出

定好优化器和损失函数,然后训练:

model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=10, batch_size=64)
  1. 指定SGD为优化器学习率为0.001,MSE为损失函数
  2. 指定数据和标签然后训练,25%为验证集,10个epochs

打印结果:

Epoch 1/10 5/5 - 0s 33ms/step - loss: 4267.9907 val_loss: 3133.0610
Epoch 2/10 5/5 0s 4ms/step - loss: 1925.8059 - val_loss: 3318.1531
Epoch 3/10 5/5 - 0s 3ms/step - loss: 181.2731 val_loss: 2728.9922
Epoch 4/10 5/5 0s 3ms/step - loss: 104.3410 - val_loss: 2093.8855
Epoch 5/10 5/5 - 0s 3ms/step - loss: 77.6116 - val_loss: 1377.6144
Epoch 6/10 5/5 0s 3ms/step - loss: 73.3877 - val_loss: 1163.6123
Epoch 7/10 5/5 0s 3ms/step - loss: 60.4262 val_loss: 867.4617
Epoch 8/10 5/5 0s 3ms/step - loss: 73.3110 - val_loss: 654.7820
Epoch 9/10 5/5 0s 3ms/step - loss: 36.6109 val_loss: 581.9786
Epoch 10/10 5/5 0s 3ms/step - loss: 56.6764 - val_loss: 383.0244
<keras.callbacks.History at 0x22634a22760>

从打印结果来看,训练集的损失和验证集的损失差距比较大,可能出现过拟合的现象

输入数据:

input_features.shape

打印结果:

(348, 14)

查看网络结构:

model.summary()

打印结果:

Model: “sequential”


Layer (type) Output Shape Param #

dense (Dense) multiple 240


dense_1 (Dense) multiple 544


dense_2 (Dense) multiple 33

Total params: 817
Trainable params: 817
Non-trainable params: 0


5、改初始化方法

model = tf.keras.Sequential()
model.add(layers.Dense(16,kernel_initializer='random_normal'))
model.add(layers.Dense(32,kernel_initializer='random_normal'))
model.add(layers.Dense(1,kernel_initializer='random_normal'))
model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=100, batch_size=64)

部分打印结果:

Epoch 99/100 261/261 0s 42us/sample - loss: 27.9759 - val_loss: 41.2864
Epoch 100/100 261/261 0s 42us/sample - loss: 44.5327 - val_loss: 48.2574

很显然差距消失了

6、加入正则化惩罚项

model = tf.keras.Sequential()
model.add(layers.Dense(16,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.add(layers.Dense(32,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.add(layers.Dense(1,kernel_initializer='random_normal',kernel_regularizer=tf.keras.regularizers.l2(0.03)))
model.compile(optimizer=tf.keras.optimizers.SGD(0.001), loss='mean_squared_error')
model.fit(input_features, labels, validation_split=0.25, epochs=100, batch_size=64)

部分打印结果:

Epoch 99/100 261/261 0s 42us/sample - loss: 26.2268 - val_loss: 20.5562
Epoch 100/100 261/261 0s 42us/sample - loss: 24.3962 - val_loss: 21.1083

很显然结果更好了

7、展示测试结果

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data = {'date': dates, 'actual': labels})

# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]

test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]

test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]

predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)}) 

# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label = 'actual')

# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label = 'prediction')
plt.xticks(rotation = '60'); 
plt.legend()

# 图名
plt.xlabel('Date'); plt.ylabel('Maximum Temperature (F)'); plt.title('Actual and Predicted Values');

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/351014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深入理解Redis:如何设置缓存数据的过期时间及其背后的机制

目录 Redis 给缓存数据设置过期时间 Redis是如何判断数据是否过期的呢&#xff1f; 过期的数据的删除策略 Redis 内存淘汰机制 Redis 给缓存数据设置过期时间 一般情况下&#xff0c;我们设置保存的缓存数据的时候都会设置一个过期时间。为什么呢&#xff1f; 因为内存是有…

4小时精通MyBatisPlus框架

目录 1.介绍 2.快速入门 2.1.环境准备 2.2.快速开始 2.2.1引入依赖 2.2.2.定义Mapper ​编辑 2.2.3.测试 2.3.常见注解 ​编辑 2.3.1.TableName 2.3.2.TableId 2.3.3.TableField 2.4.常见配置 3.核心功能 3.1.条件构造器 3.1.1.QueryWrapper 3.1.2.UpdateWra…

Redis(八)哨兵机制(sentinel)

文章目录 哨兵机制案例认识异常 哨兵运行流程及选举原理主观下线(Subjectively Down)ODown客观下线(Objectively Down)选举出领导者哨兵选出新master过程 哨兵使用建议 哨兵机制 吹哨人巡查监控后台master主机是否故障&#xff0c;如果故障了根据投票数自动将某一个从库转换为新…

Java 基础知识-File类

大家好我是苏麟 , 今天聊聊File . 资料来自黑马程序员 File类 java.io.File 类是文件和目录路径名的抽象表示&#xff0c;主要用于文件和目录的创建、查找和删除等操作。 构造方法 public File(String pathname) &#xff1a;通过将给定的路径名字符串转换为抽象路径名来创建…

盘古信息IMS OS 数垒制造操作系统+ 产品及生态部正式营运

启新址吉祥如意&#xff0c;登高楼再谱新篇。2024年1月22日&#xff0c;广东盘古信息科技股份有限公司新办公楼层正式投入使用并举行了揭牌仪式&#xff0c;以崭新的面貌、奋进的姿态开启全新篇章。 盘古信息总部位于东莞市南信产业园&#xff0c;现根据公司战略发展需求、赋能…

【双目】基于findChessboardCorners的双目精度评估,可以直接使用

1. 基于findChessboardCorners的双目精度评估 原理&#xff1a; 代码&#xff1a; #include <iostream> #include <opencv2/opencv.hpp>using namespace std; using namespace cv;int main() {// 加载图像auto srcimage imread("/home/oem/data/steroe_p…

Hadoop增加新节点环境配置(自用)

完成Hadoop集群增添一个新的节点配置&#xff08;文中命名为&#xff09;Hadoop106&#xff0c;没有进行继续为该节点分配身份职能的步骤 1.在VMware中安装CentOS 7 新建虚拟机 1.⾸先我们创建⼀个新的虚拟机&#xff0c;也可以点⽂件-新建虚拟机。 2.选择⾃定义&#xff0c…

网页元素圈选

从前面我们已验证配置自动化是可行的&#xff0c;接下来就实现元素选择&#xff0c;当然有了配置化&#xff0c;我们也是可以通过浏览器F12的调试工具去把元素xpath复制出来&#xff08;ps:反正又不是不能用&#xff09;&#xff0c;但是这不是我们最终目的。 其实圈选效果如下…

CSS3如何实现从右往左布局的按钮组(固定间距)

可以通过下方CSS实现&#xff0c;下面的CSS表示按钮从右往左布局&#xff0c;且间距为10px: .right-btn {position: relative;float: right;margin-right: 10px; }类似这种&#xff1a; 这种&#xff1a; 注意&#xff1a; 不能使用right:10px代替margin-right:10px&#x…

聚醚醚酮(Polyether Ether Ketone)PEEK主要应用于哪些行业领域?

聚醚醚酮&#xff08;Polyether Ether Ketone&#xff0c;PEEK&#xff09;广泛应用于以下行业&#xff1a; 1.航空航天业&#xff1a; PEEK常被用于制造航空航天组件&#xff0c;如飞机零部件、航天器构件&#xff0c;因其轻量化、高强度和耐高温性能。 2.汽车工业&#xff1…

滑木块H5小游戏

欢迎来到程序小院 滑木块 玩法&#xff1a;点击木块横着的只能左右移动&#xff0c;竖着的只能上下移动&#xff0c; 移动到箭头的位置即过关&#xff0c;不同关卡不同的木块摆放&#xff0c;快去滑木块吧^^。开始游戏https://www.ormcc.com/play/gameStart/260 html <can…

Django项目搭建

一、创建项目 在命令行中执行代码 $ django-admin startproject mysitedjango-admin 为内部命令startproject 为参数mysite 项目名 备注 避免使用 Python 或 Django 的内部保留字来命名项目。比如说&#xff0c;避免使用像 django (会和 Django 自己产生冲突)或 test (会和 P…

深入浅出 diffusion(3):pytorch 实现 diffusion 中的 U-Net

导入python包 import mathimport torch import torch.nn as nn import torch.nn.functional as F silu激活函数 class SiLU(nn.Module): # SiLU激活函数staticmethoddef forward(x):return x * torch.sigmoid(x) 归一化设置 def get_norm(norm, num_channels, num_groups)…

Java集合-Map接口(key-value)

Map接口的特点&#xff1a;①KV键值对方式存储②Key键唯一&#xff0c;Value允许重复③无序。 Map有四个实现类&#xff1a;1.HashMap类2.LinkedHashMap类3.TreeMap类4.Hashtable类 1.HashMap类&#xff1a; 存储结构&#xff1a;哈希表 数组Node[ ] 链表&#xff08;红黑…

暴力破解

暴力破解工具使用汇总 1.查看密码加密方式 在线网站&#xff1a;https://cmd5.com/ http://www.158566.com/ https://encode.chahuo.com/kali&#xff1a;hash-identifier2.hydra 用于各种服务的账号密码爆破&#xff1a;FTP/Mysql/SSH/RDP...常用参数 -l name 指定破解登录…

BUUCTF--xor1

这题考察的是亦或。查壳&#xff1a; 无壳。看下IDA的流程&#xff1a; 我们看到将用户输入做一个异或操作&#xff0c;然后和一个变量做比较。如果相同则输出Success。这里的知识点就是两次异或会输出原文。因此我们只需要把global再做一次异或就能解出flag。在IDA中按住shift…

Excel 2019 for Mac/Win:商务数据分析与处理的终极工具

在当今快节奏的商业环境中&#xff0c;数据分析已经成为一项至关重要的技能。从市场趋势预测到财务报告&#xff0c;再到项目管理&#xff0c;数据无处不在。而作为数据分析的基石&#xff0c;Microsoft Excel 2019 for Mac/Win正是一个强大的工具&#xff0c;帮助用户高效地处…

鸿蒙HarmonyOS获取GPS精确位置信息

参考官方文档 #1.初始化时获取经纬度信息 aboutToAppear() {this.getLocation() } async getLocation () {try {const result await geoLocationManager.getCurrentLocation()AlertDialog.show({message: JSON.stringify(result)})}catch (error) {AlertDialog.show({message…

RabbitMQ 笔记二

1.Spring 整合RabbitMQ 生产者消费者 创建生产者工程添加依赖配置整合编写代码发送消息 创建消费者工程添加依赖配置整合编写消息监听器 2.创建工程RabbitMQ Producers spring-rabbitmq-producers <?xml version"1.0" encoding"UTF-8"?> <pr…

[TCP协议]基于TCP协议的字典服务器

目录 1.TCP协议简介: 2.TCP协议在Java中封装的类以及方法 3.字典服务器 3.1服务器代码: 3.2客户端代码: 1.TCP协议简介: TCP协议是一种有连接,面向字节流,全双工,可靠的网络通信协议.它相对于UDP协议来说有以下几点好处: 1.它是可靠传输,相比于UDP协议,传输的数据更加可靠…