吴恩达deeplearning.ai:Tensorflow训练一个神经网络

以下内容有任何不理解可以翻看我之前的博客哦:吴恩达deeplearning.ai
在之前的博客中。我们陆续学习了各个方面的有关深度学习的内容,今天可以从头开始训练一个神经网络了。

Tensorflow训练神经网络模型

我们使用之前用过的例子:
在这里插入图片描述
这个神经网络有三层,第一层拥有25个神经元,第二层15个神经元,第三层为最终输出层。
现在提供一个训练集X,一个标签Y,该如何通过代码的形式来表现呢?

#1导入工具包
import tensrflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

#2创建三个层并让Tensorflow按照顺序将几个层串联起来
  model = Sequential([
    Dense(units = 25, activation = 'sigmoid')
    Dense(units = 15, activation = 'sigmoid')
    Dense(units = 1, activation = 'sigmoid')
                     ])
 #3引入工具包,并且让损失函数使用分类交叉熵的形式
from tensorflow.keras.losses import
BinaryCrossentropy
  model.compile(loss = BinaryCrossentropy())

#调用拟合函数,epoch代表训练次数
  model.fit(X, Y, epochs=100)

模型中的一些细节讲解

框架相关

让我们先复习一下之前的内容,如何实现逻辑回归的:
第一步,如何在给定输入特征X和参数W,b的情况下计算输出(定义模型),我们这里经常使用的是sigmoid函数。
第二步,指定损失函数与成本函数
第三步,训练模型,最小化J(w,b)
让我们在训练神经网络的背景下来看看这几步:

#1导入工具包
import tensrflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

#2创建三个层并让Tensorflow按照顺序将几个层串联起来
  model = Sequential([
    Dense(units = 25, activation = 'sigmoid')
    Dense(units = 15, activation = 'sigmoid')
    Dense(units = 1, activation = 'sigmoid')
                     ])

这几段代码说明了神经网络的整个架构体系,告诉你第一层有25个神经元,第二层有15个神经元,第三层一个,采用的激活函数均为sigmoid。

损失函数相关

再写一遍 损失函数的一般数学表达式:
J ( W , B ) = 1 m ∑ L ( f ( x ( i ) , y ( i ) ) J(W,B) = \frac{1}{m}\sum L(f(x^{(i)},y^{(i)}) J(W,B)=m1L(fx(i),y(i))

 #3引入工具包,并且让损失函数使用分类交叉熵的形式
from tensorflow.keras.losses import
BinaryCrossentropy
  model.compile(loss = BinaryCrossentropy())

这个名叫keras的工具包其实是和tensorflow是完全不同的两个项目开发的,只是最后合入了tensorflow,所有它的工具包需要你单独import。另外,由于工具包的种类真的很多,所以不知道工具包的名字和使用方法时可以上网查找哦。
我们在之前的博客中,曾经学习过二元交叉熵(这是统计学上的叫法),二元的意思是说明这是个布尔值,要么为1要么为0.只是在之前的博客中不叫这个名字,而是为了能够在一个式子之中写出价代价函数:
L ( f ( x ) , y ) = − y l o g ( f ( x ) ) − ( 1 − y ) l o g ( ( 1 − f ( x ) ) L(f(x),y) = -ylog(f(x)) - (1-y)log((1-f(x)) L(f(x),y)=ylog(f(x))(1y)log((1f(x))
在制定了损失函数之后,Tensorflow就知道了你是希望最小化m个训练的平均值。
如果你是想解决其它类型的问题例如回归问题,你可以给tensorflow指定其它种类的损失函数:

from tensorflow.keras.losses import MeanSquareError
model.compile(loss = MeanSquareError())

这是最小化平方误差损失的损失函数。

梯度下降

梯度下降时,你需要重复公式:
w = w − α ∂ ∂ w j J ( w , b ) b = b − α ∂ ∂ b j J ( w , b ) w = w - \alpha\frac{\partial}{\partial w_j}J(w,b)\\ b = b - \alpha\frac{\partial}{\partial b_j}J(w,b) w=wαwjJ(w,b)b=bαbjJ(w,b)

#调用拟合函数,epoch代表训练次数
  model.fit(X, Y, epochs=100)

Tensorflow使用的是一种叫做反向传播的算法来计算这些偏导数项,只是在函数model.fit中完成的,并告诉它这样迭代100次。

很明显我们现在的代码严重依赖于Tensorflow库,随着技术的发展,大部分工程师都会使用库而非自己重头编起。现在你已经了解了如何自己训练一个神经网络了,在接下来的博客中我们讲讲到一些你可以改变的地方,使得你的神经网络更加强大。
为了给读者你造成不必要的麻烦,博主的所有视频都没开仅粉丝可见,如果想要阅读我的其他博客,可以点个小小的关注哦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/408612.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python中的functools模块详解

大家好,我是海鸽。 函数被定义为一段代码,它接受参数,充当输入,执行涉及这些输入的一些处理,并根据处理返回一个值(输出)。当一个函数将另一个函数作为输入或返回另一个函数作为输出时&#xf…

JAVA算法和数据结构

一、Arrays类 1.1 Arrays基本使用 我们先认识一下Arrays是干什么用的,Arrays是操作数组的工具类,它可以很方便的对数组中的元素进行遍历、拷贝、排序等操作。 下面我们用代码来演示一下:遍历、拷贝、排序等操作。需要用到的方法如下 public…

26.HarmonyOS App(JAVA)列表对话框

列表对话框的单选模式: //单选模式 // listDialog.setSingleSelectItems(new String[]{"第1个选项","第2个选项"},1);//单选 // listDialog.setOnSingleSelectListener(new IDialog.ClickedListener() { // Override …

互联网加竞赛 机器视觉opencv答题卡识别系统

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 答题卡识别系统 - opencv python 图像识别 该项目较为新颖,适合作为竞赛课题方向,学长非常推荐! 🥇学长这里给一个题目综合评分(每项满分5分…

C++中的左值和右值

目录 一. 左值和右值的概念 1. 左值 1.1 可修改的的左值 1.2 不可修改的左值 右值 二. 左值引用和右值引用 1. 左值引用 2. 右值引用 主要用途 1. 移动语义 2. 完美转发 2.1 引用折叠 2.2 std::forward 一. 左值和右值的概念 什么是左值和右值 1. 左值 左值是一个表示…

Unity3D 使用 Proto

一. 下载与安装 这里下载Google Protobuff下载 1. 源码用来编译CSharp 相关配置 2. win64 用于编译 proto 文件 二. 编译 1. 使用VS 打开 2. 点击最上面菜单栏 工具>NuGet 包管理器>管理解决方案的NuGet 管理包 版本一定要选择咱们一开始下载的对应版本否则不兼容&am…

使用免费的L53巧解Freenom域名失效问题

进入2月份以来,不少小伙伴纷纷收到Freenom提供的域名失效,状态由正常变成了Pending。 失效后,域名无法使用,免费的午餐没有了,而现在域名的价格也是水涨船高,真是XXX。很多做外贸的小伙伴表示 难 啊&#x…

树状数组与线段树<2>——线段树初步

这个系列终于更新了(主要因为树状数组初步比较成功) 话不多说,切入正题。 什么是线段树? 线段树是一种支持单点修改区间查询(树状数组也行) and 区间修改单点查询(树状数组不行) and 区间修改区间查询(树状数组更不行)的高级数据结构,相当…

Chiplet技术与汽车芯片(二)

目录 1.回顾 2.Chiplet的优势 2.1 提升芯片良率、降本增效 2.2 设计灵活,降低设计成本 2.3 标准实行,构建生态 3.Chiplet如何上车 1.回顾 上一篇,我们将来芯粒到底是什么东西,本篇我们来看芯粒技术的优势,以及它…

5.1 Ajax数据爬取之初介绍

目录 1. Ajax 数据介绍 2. Ajax 分析 2.1 Ajax 例子 2.2 Ajax 分析方法 (1)在网页页面右键,检查 (2)找到network,ctrl R刷新 (3)找 Ajax 数据包 (4)…

多线程相关(4)

线程安全-下 使用层面锁优化减少锁的时间:减少锁的粒度:锁粗化:使用读写锁:使用CAS: 系统层面锁优化自适应自旋锁锁消除锁升级偏向锁轻量级锁重量级锁 ThreadLocal原理ThreadLocal简介原理ThreadLocal内存泄漏 HashMap…

VMware使用虚拟机,开启时报错:无法连接虚拟设备 0:0,因为主机上没有相应的设备。——解决方法

检查虚拟机配置文件并确保物理设备已正确连接。 操作: 选中虚拟机,打开设置,点击CD/DVD。在连接处选择使用ISO镜像文件

fpga_硬件加速引擎

一 什么是硬件加速引擎 硬件加速引擎,也称硬件加速器,是一种采用专用加速芯片/模块替代cpu完成复杂耗时的大算力操作,其过程不需要或者仅需要少量cpu参与。 二 典型的硬件加速引擎 典型的硬件加速引擎有GPU,DSP,ISP&a…

【二分查找】【浮点数的二分查找】【二分答案查找】

文章目录 前言一、二分查找(Binary Search)二、浮点数的二分查找三、二分答案总结 前言 今天记录一下基础算法之二分查找 一、二分查找(Binary Search) 二分查找(Binary Search)是一种在有序数组中查找目…

1 Nacos数据持久化方式

Nacos 支持两种数据持久化方式,一种是利用内置的数据库,另一种是利用外置的数据源。 1、内置数据库支持 Nacos 默认内置了一些数据存储解决方案,如内嵌的 Derby 数据库。 这种内置方式主要用于轻量级或测试环境。 2、外置数据库支持 对于生…

【RN】学习使用 Reactive Native内置UI组件

简言 当把导航处理好后,就可以学习使用ui组件了(两者没有先后关系,个人习惯)。 在 Android 和 iOS 开发中,一个视图是 UI 的基本组成部分:屏幕上的一个小矩形元素、可用于显示文本、图像或响应用户输入。甚…

如何使用逻辑回归处理多标签问题?

逻辑回归处理多分类 1、背景描述2、One vs One3、One vs Rest4、从Sigmoid到Softmax的推导 1、背景描述 逻辑回归本身只能用于二分类问题,如果实际情况是多分类的,那么就需要对模型进行一些改动。下面介绍三种常用的将逻辑回归用于多分类的方法 2、One …

目标跟踪之KCF详解

High-Speed Tracking with Kernelized Correlation Filters 使用内核化相关滤波器进行高速跟踪 大多数现代跟踪器的核心组件是判别分类器,其任务是区分目标和周围环境。为了应对自然图像变化,此分类器通常使用平移和缩放的样本补丁进行训练。此类样本集…

用Python实现创建十二星座数据分析图表

下面小编提供的代码中,您已经将pie.render()注释掉,并使用了pie.render_to_file(十二星座.svg)来将饼状图渲染到一个名为十二星座.svg的文件中。这是一个正确的做法,如果您想在文件中保存图表而不是在浏览器中显示它。 成功创建图表&#xf…

嵌入式软件分层设计的思想分析

“嵌入式开发&#xff0c;点灯一路发” 那今天我们就以控制LED闪烁为例&#xff0c;来聊聊嵌入式软件分层: ——————————— | | | P1.1 |-----I<|--------------<| | | | P2.1 |-------------/ ---------…