3 tensorflow构建的模型详解

上一篇:2 用TensorFlow构建一个简单的神经网络-CSDN博客

1、神经网络概念

接上一篇,用tensorflow写了一个猜测西瓜价格的简单模型,理解代码前先了解下什么是神经网络。

下面是百度AI对神经网络的解释:

这里不赘述太多概念相关的东西,不理解可先跳过,后面例子看懂了在返回看可能就理解了。

2、密集层

再看看上一篇预测西瓜价格的代码

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=500)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

其中这行代码:

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

意思是构建了一个模型,里面添加了一层神经网络,只有一个神经元。(Sequential是顺序的意思,Dense是密集层。)

密集层(也叫全连接层),在神经网络中指的是每个神经元都与前一层的所有神经元相连的层。

举个例子,如下图所示:神经元a1与所有输入层数据相连(X1,X2,X3),其他神经元也一样都与上一层神经元相连。

它们之间的数学关系为:

某个神经元是由连接的上一层神经元分别乘上权重(w),再加上偏差(b)得到,例如计算a1:

权重w的数字下标可以按照顺序命名,比如第一个神经元计算的权重可以为w11、w12……,第二个神经元计算的权重可以为w21、w22……

a2、a3计算以此类推。

3、西瓜预测模型详解

上一篇西瓜费用计算公式 :费用=1.2元/斤*重量+0.5元

即:y=1.2x+0.5

这是一个一元线性回归问题,只有一个自变量x和一个因变量y,机器学习要推算出权重w=1.2, 偏差b=0.5,才能准确预测费用。

下面是实现这个算法的流程:

(1)训练数据准备

西瓜重量 weight=[1, 3, 4, 5, 6, 8]

对应的费用 total_cost=[1.7, 4.1, 5.3, 6.5, 7.7, 10.1]

(2)构建模型

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

  • tf.keras.layers.Dense(1, input_shape=[1]),参数1表示1个神经元,我们只要预测费用y,所以输出层只要一个神经元就可以了(注意:神经元不用包含输入层)。
  • input_shape=[1],表示输入数据的形状为单元素列表,即每个输入数据只有一个值。因为只有一个变量x(西瓜的重量),所以此处输入形状是[1]

该模型的示意图:

可以用model.summary()查看模型摘要,代码如下:

import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])


# 查看模型摘要
model.summary()

运行结果:

可以看到可训练参数有2个,即公式中的w1和b1。

(3)设置损失函数和优化器
model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')
  • mean_squared_error是均方误差,指的是预测值与真实值差值的平方然后求和再平均。公式为:

                    MSE=1/n Σ(P-G)^2 (P为预测值,G为真实值)

  • SGD即随机梯度下降(Stochastic Gradient Descent),是一种迭代优化算法。

(4)设置训练数据
history = model.fit(weight, total_cost, epochs=500)
  • 设置训练数据的特征和标签,在上述代码中分别是西瓜的重量和费用:weight、total_cost
  • 设置训练轮次epochs=500,1个epochs是指使用所有样本训练一次。

(5) 查看训练结果

看下面的训练过程,第8个epoch的时候损失值loss已经很小了,训练轮次不需要设置到500就可以有很好的预测效果了。

刚开始loss很高,使用优化算法慢慢调整了权重,loss值可以很好地衡量我们的模型有多好。

我们把epoch的值调小,看看程序猜测的权重(w)和偏差(b)是多少,以及loss值的计算。

 

代码改动如下:

  •  epochs=5
  • 用model.get_weights()获取程序猜测的权重数据
import numpy as np
import tensorflow as tf

# 西瓜的重量
weight = np.array([1, 3, 4, 5, 6, 8], dtype=float)

# 对应的费用
total_cost = np.array([1.7, 4.1, 5.3, 6.5, 7.7, 10.1], dtype=float)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=[1])
])

model.compile(loss=tf.losses.mean_squared_error, optimizer='SGD')

history = model.fit(weight, total_cost, epochs=5)

# 获取权重数据
w = model.get_weights()[0]
b = model.get_weights()[1]

print('w:')
print(w)
print('b: ')
print(b)

# 训练完成后,预测10斤西瓜的总费用
print(model.predict([10]))

运行结果:

训练了5个epoch后,程序猜测w是1.1807659,b为0.33192113

            y=wx+b=1.1807659*10+0.33192113=12.139581

所以预测10斤西瓜的总费用是12.139581

                 


 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/109463.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

linux 系统编程复习07-信号

1 复习目标 了解信号中的基本概念熟练使用信号相关的函数参考文档使用信号集操作相关函数熟练使用信号捕捉函数signal熟练使用信号捕捉函数sigaction熟练掌握使用信号完成子进程的回收 信号介绍 信号的概念 信号是信息的载体,Linux/UNIX 环境下,古老…

【计算机网络】三次握手 四次挥手

目录 1.三次握手 2.四次挥手 3.总结 三次握手和四次挥手是有连接特有的。三次握手,四次挥手指的是TCP有连接特点的中的步骤。建立连接(三次握手),断开连接(四次挥手)。建立连接操作一般都是客户端主动发起,断开连接操作客户端和服务器都可…

Linux shell编程学习笔记17:for循环语句

Linux Shell 脚本编程和其他编程语言一样,支持算数、关系、布尔、字符串、文件测试等多种运算,同样也需要进行根据条件进行流程控制,提供了if、for、while、until等语句。 之前我们探讨了if语句,现在我们来探讨for循环语句。 Li…

rem设置 vscode设置rem 适配 px转rem

1、下载安装 2、 二、 如果代码里面设置 就按代码里面来 -- 20 代码: // 基准大小 const baseSize 20 // 设置 rem 函数 function setRem() {// 当前页面宽度相对于 750 宽的缩放比例,可根据自己需要修改。const scale document.documentElement.clientWidth / …

学会吃亏,也是善良

《六祖坛经》上说:一切福田,都离不开心地。 心田上播下善良的种子,总有一天,会开花结果。 所以,心地善良是一种福祉,是对生命最好的感恩与回报,心存善念,便是最好的修行!…

【Java每日一题】——第四十三题:编程用多态实现打印机.。分为黑白打印机和彩色打印机,不同类型的打印机打印效果不同。(2023.10.30)

🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL&#xff1a…

RK3568-适配at24c04模块

将at24c04模块连接到开发板i2c2总线上 i2ctool查看i2c2总线上都有哪些设备 UU表示设备地址的从设备被驱动占用,卸载对应的驱动后,UU就会变成从设备地址。at24c04模块设备地址 0x50和0x51是at24c04模块i2c芯片的设备地址。这个从芯片手册上也可以得知。A0 A1 A2表示的是模块对…

【鸿蒙软件开发】ArkTS基础组件之Select(下拉菜单)、Slider(滑动条)

文章目录 前言一、Select下拉菜单1.1 子组件1.2 接口参数 1.3 属性1.4 事件1.5 示例代码 二、Slider2.1 子组件2.2 接口参数:SliderStyle枚举说明 2.3 属性2.4 事件SliderChangeMode枚举说明 2.5 示例代码 总结 前言 Select组件:提供下拉选择菜单&#…

地球系统模式CESM

目前通用地球系统模式(Community Earth System Model,CESM)在研究地球的过去、现在和未来的气候状况中具有越来越普遍的应用。CESM由美国NCAR于2010年07月推出以来,一直受到气候学界的密切关注。近年升级的CESM2.0在大气、陆地、海…

Linux高级命令(扩展)

一、find命令 1、find命令作用 在Linux操作系统中,find命令主要用于进行文件的搜索。 2、基本语法 # find 搜索路径 [选项 选项的值] ... 选项说明: -name :根据文件的名称搜索文件,支持*通配符 -type :f代表普通文…

基于小安派AiPi-Eye-S1的Nes游戏机

1.作品展示 作品功能可见以下B站视频 外壳可以使用灰太狼大佬提供的外壳STL文件。在嘉立创三维猴上打印(外壳12元快递6元)。 外壳从以下的帖子中获取: 模型分享 2.作品说明 2.1 硬件部分 硬件上使用到了AiPi-Eye-S1开发板以及3.5寸 240*3…

ElasticSearch快速入门实战

全文检索 什么是全文检索 全文检索是一种通过对文本内容进行全面索引和搜索的技术。它可以快速地在大量文本数据中查找包含特定关键词或短语的文档,并返回相关的搜索结果。全文检索广泛应用于各种信息管理系统和应用中,如搜索引擎、文档管理系统、电子…

安防监控项目---CGI接口的移植和使用

文章目录 前言一、CGI二、CGI的具体移植步骤2.1 cgi源码下载2.2 搭建交叉编译环境2.3 注意事项 三、测试结果总结 前言 书接上期,上期与大家分享的是boa服务器的移植,那么几天要和大家介绍的呢是一款接口,哈哈哈,用起来也是有点难…

NSS刷题 js前端修改 os.path.join漏洞

打算刷一遍nssweb题(任重道远) 前面很简单 都是签到题 这里主要记录一下没想到的题目 [GDOUCTF 2023]hate eat snake 这里 是对js的处理 有弹窗 说明可能存在 alert 我们去看看js 这里进行了判断 如果 getScore>-0x1e9* 我们结合上面 我觉得是6…

人工智能基础_机器学习008_使用正规方程_损失函数进行计算_一元一次和二元一次方程演示_sklearn线性回归演示---人工智能工作笔记0048

自然界很多都是正态分布的,身高,年龄,体重...但是财富不是. 然后我们来看一下这个y = wx+b 线性回归方程. 然后我们用上面的代码演示. 可以看到首先import numpy as np 导入numby 数据计算库 import matplotlib.pyplot as plt 然后导入图形画的库 然后: X = np.linspace(0,…

文件正在使用,操作无法完成。windows查看占用文件的程序

查看占用 tasklist /m IDMShellExt64.dll 映像名称 PID 模块explorer.exe 7452 IDMShellExt64.dll杀死进程 taskkill /f /PID 7452 成功: 已终止 PID 为 7452 的进程。重启explorer explorer

20.3 OpenSSL 对称AES加解密算法

AES算法是一种对称加密算法,全称为高级加密标准(Advanced Encryption Standard)。它是一种分组密码,以128比特为一个分组进行加密,其密钥长度可以是128比特、192比特或256比特,因此可以提供不同等级的安全性…

OpenCV—自动驾驶实时道路车道检测(完整代码)

自动驾驶汽车是人工智能领域最具颠覆性的创新之一。在深度学习算法的推动下,它们不断推动我们的社会向前发展,并在移动领域创造新的机遇。自动驾驶汽车可以去传统汽车可以去的任何地方,并且可以完成经验丰富的人类驾驶员所做的一切。但正确地训练它是非常重要的。自动驾驶汽…

计算机考研 | 2013年 | 计算机组成原理真题

文章目录 【计算机组成原理2013年真题43题-9分】【第一步:信息提取】【第二步:具体解答】 【计算机组成原理2013年真题44题-14分】【第一步:信息提取】【第二步:具体解答】 【计算机组成原理2013年真题43题-9分】 某32位计算机&a…

sql-50练习题6-10

sql练习题6-10题 前言数据库表结构介绍学生表课程表成绩表教师表 0-6 查询"李"姓老师的数量0-7 查询学过"李四"老师授课的同学的信息0-8 查询没学过"李四"老师授课的同学的信息0-9 查询学过编号为"01"并且也学过编号为"02"的…