tensorflow1.基础案例2

前言

在TensorFlow 1.x中实现线性回归通常指的是使用静态图的方式,而在TensorFlow 1.x中使用Eager API实现线性回归是在TensorFlow 1.x的晚期版本中引入的,以提供类似于TensorFlow 2.x的编程体验。以下是两种方式的区别、各自的优点以及对比的作用:

TensorFlow 1.x 静态图方式

实现方式:

  • 首先构建一个计算图,包括数据的输入、模型的定义、损失函数和优化器。
  • 使用tf.placeholder来定义输入数据占位符。
  • 通过tf.Session()来启动会话,并在会话中执行图的操作。

优点:

  • 图优化: 静态图允许在执行前对整个计算图进行优化,可能提高执行效率。
  • 多GPU支持: 更好地支持多GPU环境,适合大规模训练。

TensorFlow 1.x Eager API

实现方式:

  • TensorFlow 1.x在其后期版本中通过tf.contrib.eager模块引入了Eager API。
  • 允许立即执行操作,无需构建静态图,操作的返回值可以直接用于进一步的计算。

优点:

  • 即时反馈: 操作执行后立即返回结果,便于调试和实验。
  • 简化的API: 减少了编写和理解代码的复杂性。

对比作用

  • 编程模型: 静态图需要事先定义所有的操作,而Eager API允许逐步执行操作。
  • 易用性: Eager API通常更易于学习和使用,特别是对于习惯了Python即时执行特性的程序员。
  • 灵活性: Eager API提供了更高的灵活性,可以动态地修改和执行操作。
  • 性能: 静态图可能在性能上更有优势,尤其是在需要执行大量预定义计算的情况下。

实际应用

  • 研究和开发: Eager API适合快速迭代和实验,因为它可以即时看到操作结果。
  • 生产环境: 静态图可能更适合生产环境,特别是当模型训练和推断需要高性能和稳定性时。

总结

尽管TensorFlow 1.x的Eager API为1.x版本带来了更现代的编程体验,但它实际上是TensorFlow 2.x中Eager执行的预演。TensorFlow 2.x默认启用Eager执行,提供了更简洁和Pythonic的API,同时保持了向后兼容性。因此,对于新项目,推荐使用TensorFlow 2.x,它结合了1.x的静态图性能和Eager API的易用性。

2.1线性回归

所需的库:

tensorflow    1.12.0
numpy         1.19.5
matplotlib    3.3.4

代码:

# 线性回归
import tensorflow as tf
import numpy
import matplotlib.pyplot as plt
rng = numpy.random

# 参数
learning_rate = 0.01
training_epochs = 1000
display_step = 200

# 训练数据
train_X = numpy.asarray([3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
                         7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1])
train_Y = numpy.asarray([1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
                         2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3])
n_samples = train_X.shape[0]

# tf 图输入
X = tf.placeholder('float')
Y = tf.placeholder('float')
W = tf.Variable(rng.randn(), name='weight')
b = tf.Variable(rng.randn(), name='bias')

# 创建一个线性模型
pre = tf.add(tf.multiply(X, W), b)

# 均方误差损失
cost = tf.reduce_sum(tf.pow(pre - Y, 2))/(2 * n_samples)
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 初始化变量(分配默认值)
init = tf.global_variables_initializer()

# 开始训练
with tf.Session() as sess:  # 创建并进入一个TensorFlow会话的上下文管理器
    sess.run(init)  # 运行初始化器init
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):  # 遍历训练数据集,每次迭代提供一对特征x和标签y
            sess.run(optimizer, feed_dict={X: x, Y: y})  # 运行优化器操作,如梯度下降步骤,来更新模型的参数。

        # 记录损失
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print(f"epoch{epoch + 1} cost= {c}, W = {sess.run(W)}, b = {sess.run(b)}")

    print("optimization finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print(f"training cost={training_cost}, W = {sess.run(W)}, b = {sess.run(b)}")

    # 画图
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(train_X, sess.run(W) * train_X + sess.run(b), label='Fitted line')
    plt.legend()
    plt.show()

输出:

epoch200 cost= 0.08749574422836304, W = 0.30706509947776794, b = 0.3880200982093811
epoch400 cost= 0.08340831100940704, W = 0.29456469416618347, b = 0.47794750332832336
epoch600 cost= 0.08090882748365402, W = 0.28478333353996277, b = 0.5483137965202332
epoch800 cost= 0.07938076555728912, W = 0.27712881565093994, b = 0.6033796668052673
epoch1000 cost= 0.07844720780849457, W = 0.27114012837409973, b = 0.6464620232582092
optimization finished!
training cost=0.07844720780849457, W = 0.27114012837409973, b = 0.6464620232582092

在这里插入图片描述

2.2 使用 TensorFlow 的 Eager API 实现线性回归

代码:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# 设置Eager API
tf.enable_eager_execution()
tfe = tf.contrib.eager

# 训练数据
train_X = [3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
           7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1]
train_Y = [1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221,
           2.827, 3.465, 1.65, 2.904, 2.42, 2.94, 1.3]
n_samples = len(train_X)

# 参数设置
learning_rate = 0.01
display_step = 200
num_steps = 1000

# 权重与偏置
W = tfe.Variable(np.random.randn())
b = tfe.Variable(np.random.randn())

# 线性回归
def linear_regression(inputs):
    return inputs * W + b

# 均方误差
def mean_square_fn(model_fn, inputs, labels):
    return tf.reduce_sum(tf.pow(model_fn(inputs) - labels, 2)) / (2 * n_samples)

# 梯度下降优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

# 计算梯度
grad = tfe.implicit_gradients(mean_square_fn)

# 初始损失
print(f"初始损失:{mean_square_fn(linear_regression, train_X, train_Y)}, W = {W.numpy()}, b = {b.numpy()}.")

# 开始训练
for step in range(num_steps):
    optimizer.apply_gradients(grad(linear_regression, train_X, train_Y))
    if (step + 1) % display_step == 0 or step == 0:
        print(f"Epoch:{step + 1}, 损失:{mean_square_fn(linear_regression, train_X, train_Y)}, W = {W.numpy()}, b = {b.numpy()}.")

plt.plot(train_X, train_Y, 'ro', label='初始数据')
plt.plot(train_X, np.array(W * train_X + b), label='拟合曲线')
plt.legend()
plt.show()

输出结果:

初始损失:25.554563522338867, W = 1.414129614830017, b = 0.1566573679447174.
Epoch:1, 损失:7.765857696533203, W = 0.9393506050109863, b = 0.09066701680421829.
Epoch:200, 损失:0.10071783512830734, W = 0.33907845616340637, b = 0.17886608839035034.
Epoch:400, 损失:0.09156356751918793, W = 0.32022035121917725, b = 0.3125614523887634.
Epoch:600, 损失:0.08593194931745529, W = 0.3054291903972626, b = 0.41742414236068726.
Epoch:800, 损失:0.0824674665927887, W = 0.2938278913497925, b = 0.4996721148490906.
Epoch:1000, 损失:0.08033612370491028, W = 0.2847285568714142, b = 0.5641824007034302.

在这里插入图片描述

小tips

权重(W)和偏置(b)的初始化已经在创建这两个变量时完成了。这是通过使用tf.Variable来实现的,并且使用numpy.random.randn()生成了随机数作为它们的初始值。:

W = tf.Variable(rng.randn(), name='weight')
b = tf.Variable(rng.randn(), name='bias')

如果想给权重和偏置赋予初始值,可以这样:

W = tf.Variable(0.4, name='weight')
b = tf.Variable(0.7, name='bias')

结果对比:

随机数赋初值:

training cost=0.07943815737962723, W = 0.22219787538051605, b = 0.998549222946167

手动赋初值:

training cost=0.07706085592508316, W = 0.25450703501701355, b = 0.7661195993423462

赋初值的优点:

  • 能够精确控制模型参数的起始点,有助于调试和验证模型的行为;
  • 从特定的值开始训练可以提高模型的稳定性和收敛速度;
  • 使用固定的初始值时,实验结果是可重复的,这对于科学研究和调试至关重要。
  • 适当的初始化可以减少过拟合的风险
  • 在超参数调优过程中,固定的初始化策略可以确保不同设置之间的比较是公平的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/792936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux下fcitx框架输入法输入中文标点时为半角(英文)标点符号的解决

目录 引入解决1.打开fcitx设置2.打开全局配置3. 随便找个可以输入地方敲下快捷键 总结 本文由Jzwalliser原创,发布在CSDN平台上,遵循CC 4.0 BY-SA协议。 因此,若需转载/引用本文,请注明作者并附原文链接,且禁止删除/修…

轻松搭建系统,让每个故事都精彩绽放!

"轻松搭建系统,让每个故事都精彩绽放!" 这句话传递了一个核心理念,即通过简化、高效的系统搭建过程,让每一个创意故事都能以最佳状态呈现给观众,实现其独特魅力和价值的最大化。 1、模块化设计:系…

CV06_Canny边缘检测算法和python实现

1.1简介 Canny边缘检测算法是计算机视觉和图像处理领域中一种广泛应用的边缘检测技术,由约翰F坎尼(John F. Canny)于1986年提出。它是基于多级处理的边缘检测方法,旨在实现以下三个优化目标: 好的检测:尽…

机器学习第四十七周周报 CF-LT

文章目录 week47 CF-LT摘要Abstract1. 题目2. Abstract3. 网络结构3.1 CEEMDAN(完全自适应噪声集合经验模态分解)3.2 CF-LT模型结构3.3 SHAP 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程 5. 结论6.代码复现小结参考文献 week47 CF-LT 摘要 本周…

以太网连接、本地连接、宽带连接和无线WLAN连接有什么不同?

电脑上的以太网连接、本地连接、宽带连接和无线WLAN连接在功能和实现方式上存在一定的区别。以下是对这四种连接方式的详细解析: 一、以太网连接与本地连接 1. 定义与关系 以太网连接:以太网是一种广泛应用的局域网(LAN)技术&a…

【YOLOv8】 用YOLOv8实现数字式工业仪表智能读数(二)

上一篇圆形表盘指针式仪表的项目受到很多人的关注,咱们一鼓作气,把数字式工业仪表的智能读数也研究一下。本篇主要讲如何用YOLOV8实现数字式工业仪表的自动读数,并将读数结果进行输出,若需要完整数据集和源代码可以私信。 目录 &…

排序(二)——快速排序(QuickSort)

欢迎来到繁星的CSDN,本期内容包括快速排序(QuickSort)的递归版本和非递归版本以及优化。 一、快速排序的来历 快速排序又称Hoare排序,由霍尔 (Sir Charles Antony Richard Hoare) ,一位英国计算机科学家发明。霍尔本人是在发现冒泡排序不够快…

MQTT协议网关解决方案及实施简述-天拓四方

MQTT协议网关是一个中间件,负责接收来自不同MQTT客户端的消息,并将这些消息转发到MQTT服务器;同时,也能接收来自MQTT服务器的消息,并将其转发给相应的MQTT客户端。MQTT协议网关的主要功能包括协议转换、消息过滤、安全…

ImportError: DLL load failed while importing cv2解决方案

系统是 server 2012 r2 datacenter 背景:在window10系统上采用PyInstaller打包python310版本程序后,在server 2012 r2 datacenter运行报错ImportError: DLL load failed while importing cv2,最后解决方案参考了一篇文章下评论修改成功。 原…

Qt http网络编程

学习目标:Qt HTTP网络编程 学习内容 1、Http就是超文本传输协议(Hypertext Transfer Protocol)的缩写,它定义了浏览器和网页服务器之间的通信规范。是一个简单的请求一响应协议,它通常运行在 TCP 之上。 作用:规定 WWW 服务器与浏览器之间信息传递规范…

软件测试常见面试题汇总(2024版)

一、常见的面试题汇总 1、你做了几年的测试、自动化测试,说一下 selenium 的原理是什么? 我做了五年的测试,1年的自动化测试; selenium 它是用 http 协议来连接 webdriver ,客户端可以使用 Java 或者 Python 各种编…

软件测试想转职有适合的岗位吗?

软件测试被有些人看做技术含量低,但是软件测试实际上是万金油行业,如果你不是在很大的公司做的软件测试,相比你做的工作是很杂的,比如软件测试找bug,你的主业,帮着产品经理整理需求,帮着项目经理…

微软开源项目GraphRAG——基于知识图谱的RAG简介

前言 在大型语言模型(LLM)的前沿研究中,一个核心挑战与机遇并存的领域是扩展它们的能力,以解决超出其训练数据范畴的问题。这不仅要求模型在面对全新数据时仍能保持卓越表现,还意味着开辟了全新的数据分析可能性&…

【C++】C++ 汽车租赁管理系统(源码+论文)【500+行代码】【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

CAN总线实战项目:使用STM32和PCAN-View实现数据采集与监控系统(附完整代码)

摘要: 本文深入浅出地介绍CAN(Controller Area Network,控制器局域网络)总线协议,涵盖其基础概念、报文帧格式、仲裁机制、错误处理等关键知识。同时,文章结合STM32平台,从硬件设计、软件开发到实战案例&am…

【益起童行】为“来自星星的孩子”点亮希望之光

在未来的日子里, 我期望每一个孩子都能得到优质的干预治疗,让他们在未来能够过上正常、快乐的生活。 我也期望每一个家庭都能战胜困境,迎来美好。 作为社会的一份子,我愿意为这繁华人世贡献出自己微不足道但却真挚的力量&#xff…

24暑假计划

暑假计划: 1.从明天起开始将C语言的部分补充完整,这部分的预计在7月24日前完成 2.由于之前的文章内容冗余,接下来进行C语言数据结构的重新编写和后面内容的补充预计8月10号前完成 3.后续开始C的初级学习

新加坡很火的slots游戏代投Facebook广告新流量趋势

新加坡很火的slots游戏代投Facebook广告新流量趋势 在新加坡这片充满活力的土地上,Slots游戏以其独特的魅力和吸引力,迅速成为了许多玩家的心头好。而Facebook,作为全球最大的社交媒体平台之一,为Slots游戏的推广提供了得天独厚的…

element-plus 按需导入问题 404等问题

场景 新开一个项目,需要用element-plus这个ui库,使用按需引入。 这是我项目的一些版本号 "element-plus": "^2.7.6","vue": "^3.2.13","vue-router": "^4.0.3",过程(看解决方法…

【MySQL】常见的MySQL日志都有什么用?

MySQL日志的内容非常重要,面试中经常会被问到。同时,掌握日志相关的知识也有利于我们理解MySQL 底层原理,必要时帮助我们排查解决问题。 MySQL中常见的日志类型主要有下面几类(针对的是InnoDB 存储引擎): 错误日志(error log):对 MySQL 的启…