【深度学习】TensorFlow深度模型构建:训练一元线性回归模型

文章目录

  • 1. 生成拟合数据集
  • 2. 构建线性回归模型数据流图
  • 3. 在Session中运行已构建的数据流图
  • 4. 输出拟合的线性回归模型
  • 5. TensorBoard神经网络数据流图可视化
  • 6. 完整代码

本文讲解:

以一元线性回归模型为例,

  • 介绍如何使用TensorFlow 搭建模型 并通过会话与后台建立联系,并通过数据来训练模型,求解参数, 直到达到预期结果为止。
  • 学习如何使用TensorBoard可视化工具来展示网络图、张量的指标变化、张量的分布情况等。

设给定一批由 y=3x+2生成的数据集( x ,y ),建立线性回归模型h(x)= wx + b ,预测出 w=3 和 b=2。

 

1. 生成拟合数据集

数据集只含有一个特征向量,注意误差项需要满足高斯分布(正态分布),程序使用了NumPy和Matplotlib库。

  • NumPy是Python的一个开源数值科学计算库,可用来存储和处理大型矩阵
  • Matplotlib是Python的绘图库,它可与NumPy一起使用,提供了一种有效的MATLAB开源替代方案。

其代码如下:

# 首先导入3个库
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# 随机产生100个数据点,随机概率符合高斯分布(正态分布)
num_points = 100
vectors_set = []
for i in range(num_points):
    # Draw random samples from a normal (Gaussian) distribution.
    x1 = np.random.normal(0., 0.55)
    y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
    # 坐标点
    vectors_set.append([x1, y1])
# 定义特征向量x
x_data = [v[0] for v in vectors_set]
# 定义标签向量y
y_data = [v[1] for v in vectors_set]


# 按[x_data,y_data]在X-Y坐标系中以打点方式显示,调用plt建立坐标系并将值输出
plt.scatter(x_data, y_data, c='b')
plt.show()

在这里插入图片描述

 

2. 构建线性回归模型数据流图

# 利用TensorFlow随机产生w和b,为了图形显示需要,分别定义名称 myw 和 myb
w = tf.Variable(tf.compat.v1.random_uniform([1], -1., 1.), name='myw')
b = tf.Variable(tf.zeros([1]), name='myb')
# 根据随机产生的w和b,结合上面随机产生的特征向量x_data,经过计算得出预估值
y = w * x_data + b
# 以预估值y和实际值y_data之间的均方差作为损失
loss = tf.reduce_mean(tf.square(y - y_data, name='mysquare'), name='myloss')
# 采用梯度下降法来优化参数
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='mytrain')

 

3. 在Session中运行已构建的数据流图

# global_variables_initializer初始化Variable等变量
sess = tf.compat.v1.Session()
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
print("w=", sess.run(w), "b= ", sess.run(b), sess.run(loss))
# 迭代20次train
for step in range(20):
    sess.run(train)
    print("w=", sess.run(w), "b=", sess.run(b), sess.run(loss))

输出w和b,损失值的变化情况,可以看到损失值从0.42降到了0.001。当然每次拟合的结果都不一致。
在这里插入图片描述

 

4. 输出拟合的线性回归模型

plt.scatter(x_data, y_data, c='b')
plt.plot(x_data, sess.run(w) * x_data + sess.run(b))
plt.show()

在这里插入图片描述

 

5. TensorBoard神经网络数据流图可视化

TensorBoard 是 TensorFlow 的可视化工具包 , 使用者通过TensorBoard可以将代码实现的数据流图以可视化的图形显示在浏览器中,这样方便使用者编写和调试TensorFlow数据流图程序。

首先,将数据流图写入到文件中

# 写入磁盘,以供TensorBoard在浏览器中展示
writer = tf.compat.v1.summary.FileWriter("./mytmp", sess.graph)

运行该代码后就可以将整个神经网络节点信息写入./mytmp目录下。

 
打开终端,执行如下命令

tensorboard --logdir=./tensflow-demo/mytmp

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6007/ (Press CTRL+C to quit)

访问 http://localhost:6007/,如下图生成的神经网络数据流图

在这里插入图片描述

通过添加参数--bind_all 将图暴露给网络。

 

6. 完整代码

# 首先导入3个库
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# 随机产生100个数据点,随机概率符合高斯分布(正态分布)
num_points = 100
vectors_set = []
for i in range(num_points):
    # Draw random samples from a normal (Gaussian) distribution.
    x1 = np.random.normal(0., 0.55)
    y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
    # 坐标点
    vectors_set.append([x1, y1])
# 定义特征向量x
x_data = [v[0] for v in vectors_set]
# 定义标签向量y
y_data = [v[1] for v in vectors_set]

# 按[x_data,y_data]在X-Y坐标系中以打点方式显示,调用plt建立坐标系并将值输出
# plt.scatter(x_data, y_data, c='b')
# plt.show()

tf.compat.v1.disable_v2_behavior()

# 利用TensorFlow随机产生w和b,为了图形显示需要,分别定义名称myw 和 myb
w = tf.Variable(tf.compat.v1.random_uniform([1], -1., 1.), name='myw')
b = tf.Variable(tf.zeros([1]), name='myb')
# 根据随机产生的w和b,结合上面随机产生的特征向量x_data,经过计算得出预估值
y = w * x_data + b
# 以预估值y和实际值y_data之间的均方差作为损失
loss = tf.reduce_mean(tf.square(y - y_data, name='mysquare'), name='myloss')
# 采用梯度下降法来优化参数
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='mytrain')

# global_variables_initializer初始化Variable等变量
sess = tf.compat.v1.Session()
init = tf.compat.v1.global_variables_initializer()
sess.run(init)
print("w=", sess.run(w), "b= ", sess.run(b), sess.run(loss))
# 迭代20次train
for step in range(20):
    sess.run(train)
    print("w=", sess.run(w), "b=", sess.run(b), sess.run(loss))

# 写入磁盘,􏰀供TensorBoard在浏览器中展示
# writer = tf.compat.v1.summary.FileWriter("./mytmp", sess.graph)
#
plt.scatter(x_data, y_data, c='b')
plt.plot(x_data, sess.run(w) * x_data + sess.run(b))
plt.show()

因为运行的是TensorFlow 1.x 系统运行的是 TensorFlow 2.x.,所以运行过程中有两个问题:

1.没有Session

在 TF2 中可以通过 tf.compat.v1.Session() 访问会话

 

2.loss passed to Optimizer.compute_gradients should be a function when eager execution is enabled

在代码前面添加如下代码,屏蔽v2的行为

tf.compat.v1.disable_v2_behavior()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/248976.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据泄露警报:不同行业危机解析与迅软DSE的拯救之道

在如今全球信息数字化不断加速的时代里,数据资料的价值更为突出,根据IBM数据显示,数据泄露的平均成本接近440万美元。一旦泄露可能意味着丢失信息、声誉受损,并可能导致延误和生产力损失。那么不同行业一旦发生了数据泄露将会面临…

Linux部署MySQL5.7和8.0版本 | CentOS和Ubuntu系统详细步骤安装

一、MySQL数据库管理系统安装部署【简单】 简介 MySQL数据库管理系统(后续简称MySQL),是一款知名的数据库系统,其特点是:轻量、简单、功能丰富。 MySQL数据库可谓是软件行业的明星产品,无论是后端开发、…

Redis——02,redis-benchmark 性能测试

redis-benchmark 性能测试 一、benchmark 性能测试。二、参数详解: 一、benchmark 性能测试。 在bin目录下,有一个redis-benchmark 工具,是用来测试性能的。 二、参数详解: http://doc.yaojieyun.com/www.runoob.com/redis/re…

VMP泄露编译的一些注意事项

VMP编译教程 鉴于VMP已经在GitHub上被大佬强制开源,特此出一期编译教程。各位熟悉的可以略过,不熟悉的可以参考一下。 环境(软件) Visual Studio 2015 - 2022 (建议使用VS2019,Qt插件只有这个版本及以上…

Python等比例缩放图片并修改对应的Labelme标注文件(v2.0)

Python等比例缩放图片并修改对应的Labelme标注文件(v2.0) 前言前提条件相关介绍实验环境Python等比例缩放图片并修改对应的Labelme标注文件Json文件代码实现输出结果 前言 此版代码,相较于Python等比例缩放图片并修改对应的Labelme标注文件&a…

原子学习笔记1——阻塞和非阻塞IO

阻塞式 I/O 顾名思义就是对文件的 I/O 操作(读写操作)是阻塞式的,非阻塞式 I/O 同理就是对文件的I/O 操作是非阻塞的。 当对文件进行读操作时,如果数据未准备好、文件当前无数据可读,那么读操作可能会使调用者阻塞&…

编程实际应用实例:洗车店会员管理系统操作教程

一、前言 洗车店在会员管理有时候需要一卡多用,基本也不需要做卡,直接报手机号或车牌号即可完成电子会员卡录入。 下面以 佳易王洗车店会员管理系统软件为例说明, 软件试用版下载或技术支持可以点击下方的官网卡片 如图:这个卡…

[HCTF 2018]WarmUp (代码审计)

打开题目: 好好好。 看看源码: ? source.php 让我看看! 发现还有个文件叫hint,php 看看: 得到目的文件是ffffllllaaaagggg 分析代码: $_REQUEST 变量 $_REQUEST用于收集HTML表单提交的数据&#x…

迅为RK3568开发板使用OpenCV处理图像-ROI区域-位置提取ROI

在图像处理过程中,我们可能会对图像的某一个特定区域感兴趣,该区域被称为感兴趣区域(Region of Interest, ROI)。在设定感兴趣区域 ROI 后,就可以对该区域进行整体操作。 位置提取 ROI 本小节代码在配套资料“iTOP-3…

RocketMQ系统性学习-RocketMQ领域模型及Linux下单机安装

MQ 之间的对比 三种常用的 MQ 对比,ActiveMQ、Kafka、RocketMQ 性能方面: 三种 MQ 吞吐量级别为:万,百万,十万消息发送时延:毫秒,毫秒,微秒可用性:主从,分…

【深度学习】机器学习概述(一)机器学习三要素——模型、学习准则、优化算法

​ 文章目录 一、基本概念二、机器学习的三要素1. 模型a. 线性模型b. 非线性模型 2. 学习准则a. 损失函数1. 0-1损失函数2. 平方损失函数(回归问题)3. 交叉熵损失函数(Cross-Entropy Loss)4. Hinge 损失函数 b. 风险最小化准则1.…

MQTT 介绍与学习 —— 筑梦之路

之前写过的相关文章: MQTT协议(转载)——筑梦之路_mqtt url-CSDN博客 k8s 部署mqtt —— 筑梦之路-CSDN博客 CentOS 7 搭建mqtt服务——筑梦之路_腾讯云宝塔搭 centos 7.9.2009 x86_64 建标准mqtt服务器-CSDN博客 mqtt简介 MQTT&#xff…

tcp/ip协议2实现的插图,数据结构5 (22 - 章)

(103) 103 二二1 协议控制块 结构 file, socket , rawcb , inpcb , tcpcb 之间的联系 (104) (105)

Python:如何将MCD12Q1\MOD11A2\MOD13A2原始数据集批量输出为TIFF文件(镶嵌/重投影/)?

博客已同步微信公众号:GIS茄子;若博客出现纰漏或有更多问题交流欢迎关注GIS茄子,或者邮箱联系(推荐-见主页). 00 前言 之前一段时间一直使用ENVI IDL处理遥感数据,但是确实对于一些比较新鲜的东西IDL并没有python那么好的及时性&…

【Linux】使用官方脚本自动安装 Docker(Ubuntu 22.04)

前言 Docker是一种开源平台,用于开发、交付和运行应用程序。它利用了容器化技术,使开发人员能够将应用程序及其依赖项打包到一个称为Docker容器的可移植容器中。这些容器可以在任何运行Docker的机器上快速、一致地运行,无论是开发环境、测试…

微服务架构之争:Quarkus VS Spring Boot

在容器时代(“Docker时代”),无论如何,Java仍然活着。Java在性能方面一直很有名,主要是因为代码和真实机器之间的抽象层,多平台的成本(一次编写,随处运行——还记得吗?&a…

Ubuntu虚拟机怎么设置静态IP

1 首先先ifconfig看一下使用的是哪个网络接口: 2 编辑 sudo vi /etc/netplan/00-installer-config.yamlnetwork:ethernets:ens33: # 根据您的网络接口进行修改,有的是eth0,有的是ens33,具体看第一步显示的是哪个网络接口addres…

【css】css实现文字两端对齐效果:

文章目录 一、方法1:二、方法2:三、注意: 一、方法1: 给元素设置 text-align: justify;text-align-last: justify;并且加上text-justify: distribute-all-line; 目的是兼容ie浏览器 p{width: 130px;text-align: justify;text-alig…

教育数字化转型 赋能家庭场景自主学习习惯养成

北京市气象台12月12日22时升级发布暴雪橙色预警信号,北京市教委决定自12月13日开始,全市中小学幼儿园采取学生临时居家学习措施。自疫情以来,家庭已经成为另一个学习中心,学校不再是教育的孤岛。 学习方式的变革,数字…

【️什么是分布式系统的一致性 ?】

😊引言 🎖️本篇博文约8000字,阅读大约30分钟,亲爱的读者,如果本博文对您有帮助,欢迎点赞关注!😊😊😊 🖥️什么是分布式系统的一致性 &#xff1f…