【AI知识点】反向传播(Backpropagation)

反向传播(Backpropagation) 是训练神经网络的核心算法,它通过反向逐层计算损失函数对每个权重的梯度,来反向逐层更新网络的权重,从而最小化损失函数。


一、反向传播的基本概念

1. 前向传播(Forward Propagation)

在前向传播中,输入数据从输入层通过隐藏层传递到输出层。网络通过层与层之间的连接(即权重)来计算每个节点的输出,最终生成网络的预测结果。

2. 计算损失(Compute Loss)

将网络的预测输出与真实值进行比较,计算损失函数(如均方误差),用来衡量网络的预测输出与真实值的差距。

3. 反向传播(Backward Propagation)

反向传播的过程主要由链式法则驱动。它通过逐层计算误差对权重的偏导数(梯度),从输出层反向传递到隐藏层,再传递到输入层(与前向传播顺序相反),以反向更新每层的权重,减少预测误差。

  • 前向传播相当于将输入数据从输入层逐步传递到输出层,得到预测结果。
  • 反向传播相当于从输出层开始反向传递误差,更新每一层的权重,使得网络在下次预测时能够减少误差。

4. 权重更新(Weights Update)

使用优化算法(如梯度下降)根据梯度更新权重。使得下一次前向传播时损失函数值减小。


二、反向传播的数学推导

对于一个简单的神经网络,损失函数 L L L 是关于网络输出 y y y 和真实值 t t t 的函数,而网络输出 y y y 又是关于输入 x x x 和权重 w w w 的函数。

通过链式法则,损失函数对权重的梯度可以表示为:

∂ L ∂ w = ∂ L ∂ y ⋅ ∂ y ∂ w \frac{\partial L}{\partial w} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w} wL=yLwy


三、反向传播的图示

在这里插入图片描述
图片来源:https://ai.stackexchange.com/questions/31566/different-ways-to-calculate-backpropagation-derivatives-any-difference

  • 前向传播(蓝色箭头)负责计算输出预测值(Out)和误差(Err)。
  • 反向传播(绿色和红色箭头)从输出误差(Err)开始,将误差逐层传播到隐藏层( a a a)和输入层(X),计算每个权重(W)的梯度,用于后续的权重更新。

四、反向传播的简单计算示例

假设我们有一个简单的两层神经网络:

在这里插入图片描述

  • 输入层(x):一个节点,输入值为 x x x
  • 隐藏层(a):一个节点,激活函数为 Sigmoid 函数。
  • 输出层(y):一个节点,激活函数为线性函数,输出值为 y y y

网络的权重:

  • 输入层到隐藏层的权重: w 1 w_1 w1
  • 隐藏层到输出层的权重: w 2 w_2 w2

给定以下初始条件:

  • 输入 x = 1 x = 1 x=1
  • 目标输出 t = 0 t = 0 t=0
  • 初始权重 w 1 = 0.5 w_1 = 0.5 w1=0.5 w 2 = 0.5 w_2 = 0.5 w2=0.5
  • 学习率 η = 0.1 \eta = 0.1 η=0.1

步骤1:前向传播

  1. 计算隐藏层的输入和输出

z = w 1 ⋅ x = 0.5 ⋅ 1 = 0.5 z = w_1 \cdot x = 0.5 \cdot 1 = 0.5 z=w1x=0.51=0.5

隐藏层的激活输出(使用 Sigmoid 函数):

a = σ ( z ) = 1 1 + e − z = 1 1 + e − 0.5 ≈ 0.6225 a = \sigma(z) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-0.5}} \approx 0.6225 a=σ(z)=1+ez1=1+e0.510.6225

  1. 计算输出层的输入和输出

y = w 2 ⋅ a = 0.5 ⋅ 0.6225 = 0.3112 y = w_2 \cdot a = 0.5 \cdot 0.6225 = 0.3112 y=w2a=0.50.6225=0.3112


步骤2:计算损失

使用均方误差(MSE)作为损失函数:

L = 1 2 ( y − t ) 2 = 1 2 ( 0.3112 − 0 ) 2 ≈ 0.0484 L = \frac{1}{2}(y - t)^2 = \frac{1}{2}(0.3112 - 0)^2 \approx 0.0484 L=21(yt)2=21(0.31120)20.0484


步骤3:反向传播

  1. 计算输出层对权重 w 2 w_2 w2 的梯度

∂ L ∂ w 2 = ∂ L ∂ y ⋅ ∂ y ∂ w 2 \frac{\partial L}{\partial w_2} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial w_2} w2L=yLw2y

计算各部分:

  • 损失函数对输出 y y y 的导数:

∂ L ∂ y = y − t = 0.3112 − 0 = 0.3112 \frac{\partial L}{\partial y} = y - t = 0.3112 - 0 = 0.3112 yL=yt=0.31120=0.3112

  • 输出 y y y 对权重 w 2 w_2 w2 的导数:

∂ y ∂ w 2 = a = 0.6225 \frac{\partial y}{\partial w_2} = a = 0.6225 w2y=a=0.6225

  • 合并计算梯度:

∂ L ∂ w 2 = 0.3112 × 0.6225 ≈ 0.1938 \frac{\partial L}{\partial w_2} = 0.3112 \times 0.6225 \approx 0.1938 w2L=0.3112×0.62250.1938

  1. 计算隐藏层对权重 w 1 w_1 w1 的梯度

∂ L ∂ w 1 = ∂ L ∂ a ⋅ ∂ a ∂ z ⋅ ∂ z ∂ w 1 \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w_1} w1L=aLzaw1z

计算各部分:

  • 损失函数对隐藏层输出 a a a 的导数:

∂ L ∂ a = ∂ L ∂ y ⋅ ∂ y ∂ a = ( y − t ) ⋅ w 2 = 0.3112 ⋅ 0.5 = 0.1556 \frac{\partial L}{\partial a} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial a} = (y - t) \cdot w_2 = 0.3112 \cdot 0.5 = 0.1556 aL=yLay=(yt)w2=0.31120.5=0.1556

  • 隐藏层输出 a a a 对输入 z z z 的导数(Sigmoid 函数导数):

∂ a ∂ z = a ( 1 − a ) = 0.6225 ⋅ ( 1 − 0.6225 ) ≈ 0.2350 \frac{\partial a}{\partial z} = a (1 - a) = 0.6225 \cdot (1 - 0.6225) \approx 0.2350 za=a(1a)=0.6225(10.6225)0.2350

  • 输入 z z z 对权重 w 1 w_1 w1 的导数:

∂ z ∂ w 1 = x = 1 \frac{\partial z}{\partial w_1} = x = 1 w1z=x=1

  • 合并计算梯度:

∂ L ∂ w 1 = 0.1556 × 0.2350 × 1 ≈ 0.0365 \frac{\partial L}{\partial w_1} = 0.1556 \times 0.2350 \times 1 \approx 0.0365 w1L=0.1556×0.2350×10.0365


步骤4:更新权重

使用梯度下降法更新权重:

  1. 更新权重 w 2 w_2 w2

w 2 new = w 2 − η ⋅ ∂ L ∂ w 2 = 0.5 − 0.1 × 0.1938 ≈ 0.4806 w_2^{\text{new}} = w_2 - \eta \cdot \frac{\partial L}{\partial w_2} = 0.5 - 0.1 \times 0.1938 \approx 0.4806 w2new=w2ηw2L=0.50.1×0.19380.4806

  1. 更新权重 w 1 w_1 w1

w 1 new = w 1 − η ⋅ ∂ L ∂ w 1 = 0.5 − 0.1 × 0.0365 ≈ 0.4963 w_1^{\text{new}} = w_1 - \eta \cdot \frac{\partial L}{\partial w_1} = 0.5 - 0.1 \times 0.0365 \approx 0.4963 w1new=w1ηw1L=0.50.1×0.03650.4963


步骤5:验证更新后的网络

再次进行前向传播,计算新的输出和损失。

  1. 新的隐藏层输入和输出

z ′ = w 1 new ⋅ x = 0.4963 ⋅ 1 = 0.4963 z' = w_1^{\text{new}} \cdot x = 0.4963 \cdot 1 = 0.4963 z=w1newx=0.49631=0.4963

a ′ = σ ( z ′ ) = 1 1 + e − 0.4963 ≈ 0.6216 a' = \sigma(z') = \frac{1}{1 + e^{-0.4963}} \approx 0.6216 a=σ(z)=1+e0.496310.6216

  1. 新的输出层输出

y ′ = w 2 new ⋅ a ′ = 0.4806 ⋅ 0.6216 ≈ 0.2988 y' = w_2^{\text{new}} \cdot a' = 0.4806 \cdot 0.6216 \approx 0.2988 y=w2newa=0.48060.62160.2988

  1. 新的损失

L ′ = 1 2 ( y ′ − t ) 2 = 1 2 ( 0.2988 − 0 ) 2 ≈ 0.0447 L' = \frac{1}{2}(y' - t)^2 = \frac{1}{2}(0.2988 - 0)^2 \approx 0.0447 L=21(yt)2=21(0.29880)20.0447


结果分析

更新权重后,损失从 0.0484 减少到 0.0447,说明网络朝着最小化损失的方向更新,模型性能有所提升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/888222.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Hbase要点简记

Hbase要点简记 Hbase1、底层架构2、表逻辑结构 Hbase HBase是一个分布式的、列式的、实时查询的、非关系型数据库,可以处理PB级别的数据,吞吐量可以到的百万查询/每秒。主要应用于接口等实时数据应用需求,针对具体需求,设计高效率…

国外电商系统开发-运维系统文件上传-高级上传

如果您要上传文件到10台服务器中,有3台服务器的路径不是一样的,那么在这种情况下您就可以使用本功能,单独执行不一样的路径 点击【高级】上传

常见的图像处理算法:均值滤波----mean filter

一、什么是均值滤波 均值滤波器是一种常见的图像滤波器,是典型的线性滤波算法。其基本原理是用一个给定的窗口覆盖图像中的每一个像素点,将窗口内的像素值求平均值,然后用这个平均值代替原来的像素值。均值滤波器可以去除噪声、平滑图像、减少…

python爬虫案例——处理验证码登录网站(12)

文章目录 前言1、任务目标2、网页分析3、代码编写4、第三方验证码识别平台(超级鹰)前言 我们在爬取某些网站数据时,可能会遇到必须登陆才能获取网页内容的情况,而大部分网站登录都需要输入验证码才能登录成功,所以接下来我将会通过实际案例来讲解如何实现验证码登录网站 1…

嵌入式硬件设计知识详解

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 💞当前专栏…

C语言初步介绍(初学者,大学生)【上】

1.C语⾔是什么? ⼈和⼈交流使⽤的是⾃然语⾔,如:汉语、英语、⽇语 那⼈和计算机是怎么交流的呢?使⽤ 计算机语⾔ 。 ⽬前已知已经有上千种计算机语⾔,⼈们是通过计算机语⾔写的程序,给计算机下达指令&am…

C#高级编程笔记--字符串和正则表达式

本章的主要内容如下: 创建字符串 格式化表达式 正则表达式​​​​​​​ 1.1 System.String类 System.String是一个类,专门用于存储字符串,允许对字符串进行许多操作。由于这种数据类型非常重要,C#提供了它自己…

‌图片编辑为底片,智能工具助力,创作精彩视觉作品

在当今数字化时代,图像编辑已成为表达创意和美化视觉作品的重要手段。借助智能工具,即使是初学者也能轻松驾驭图片编辑。接下为大家展示图片编辑为底片图片的效果。 1.打开“首助编辑高手”,选择这里“图片批量处理”版块页面上 2.导入保存有…

【AIGC产品经理】面试7家,拿到2个offer,薪资中上水平

Hello,大家好,我是一名不知名的5年B端金融产品经验的产品经理,成功转行AI产品经理,前期面试了北京百度、阿里、理想汽车、百川智能、华为、OPPO等多家大厂面试,但是由于已定居成都,主动终止了后续需要线下的…

uniapp 游戏 - 使用 uniapp 实现的扫雷游戏

0. 思路 1. 效果图 2. 游戏规则 扫雷的规则很简单。盘面上有许多方格,方格中随机分布着一些雷。你的目标是避开雷,打开其他所有格子。一个非雷格中的数字表示其相邻 8 格子中的雷数,你可以利用这个信息推导出安全格和雷的位置。你可以用右键在你认为是雷的地方插旗(称为标…

Web和UE5像素流送、通信教程

一、web端配置 首先打开Github地址:https://github.com/EpicGamesExt/PixelStreamingInfrastructure 找到自己虚幻引擎对应版本的项目并下载下来,我这里用的是5.3。 打开项目找到PixelStreamingInfrastructure-master > Frontend > implementat…

Docker系列-5种方案超详细讲解docker数据存储持久化(volume,bind mounts,NFS等)

文章目录 Docker的数据持久化是什么?1.数据卷(Data Volumes)使用Docker 创建数据卷创建数据卷创建一个容器,将数据卷挂载到容器中的 /data 目录。进入容器,查看数据卷内容停止并重新启动容器,数据卷中的数据…

Vue2电商项目(八) 完结撒花:图片懒加载、路由懒加载、打包的map文件

一、图片懒加载 安装:npm i vue-lazyload1.3 -s (弹幕建议按1.3版本) 引入 // 引入懒加载的图片 import hlw from /assets/hulu.jpg // 引入插件 import VueLazyload from vue-lazyload // 引入插件 Vue.use(VueLazyload, {// 懒加载默认的图…

【Linux-基础IO】磁盘的存储管理详解

磁盘的存储管理 由于一个磁盘中包含了大量的扇区,为了方便管理,我们对磁盘进行了分区,其中每个分区又进一步划分为多个块组(Block Group),每个块组中包含该块组的数据存储情况以及具体的数据 假设有一个8…

前端练习小项目 —— 让图片变得更 “色”

前言:相信读者在学习完了HTML、CSS和JavaScript之后已经想要迫不及待的想找一个小型的项目来练练手,那么这篇文章就正好能满足你的 “需求”。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我的主页秋刀鱼不做梦-CSDN博客 在开始学习…

SpringBoot基础(三):Logback日志

SpringBoot基础系列文章 SpringBoot基础(一):快速入门 SpringBoot基础(二):配置文件详解 SpringBoot基础(三):Logback日志 目录 一、日志依赖二、日志格式1、记录日志2、默认输出格式3、springboot默认日志配置 三、日志级别1、基础设置2、…

Linux中的网络指令:ping、netstat、watch、pidof、xargs

目录 Ping指令 netstat指令 watch指令 pidof指令 xargs指令 Ping指令 功能:检测两台主机间的网络连通性 语法:ping [选项] 目标主机的IP地址 (192.168.1.1)或域名(google.com) 常见选项&#xff1a…

P1010 [NOIP1998 普及组] 幂次方 Python题解

[NOIP1998 普及组] 幂次方 题目描述 任何一个正整数都可以用 2 2 2 的幂次方表示。例如 137 2 7 2 3 2 0 1372^7 2^3 2^0 137272320。 同时约定次方用括号来表示,即 a b a^b ab 可表示为 a ( b ) a(b) a(b)。 由此可知, 137 137 137 可表示…

华为 HCIP-Datacom H12-821 题库 (33)

🐣博客最下方微信公众号回复题库,领取题库和教学资源 🐤诚挚欢迎IT交流有兴趣的公众号回复交流群 🦘公众号会持续更新网络小知识😼 1.VLAN Pool 只要通过一个 SSID 就能够同时支持多个业务 VLAN,从而缩小广播域&#…

draw.io 设置默认字体及添加常用字体

需求描述 draw.io 是一个比较好的开源免费画图软件。但是其添加容器或者文本框时默认的字体是 Helvetica,一般的期刊、会议论文或者学位论文要求的英文字体是 Times New Roman,中文字体是 宋体,所以一般需要在文本字体选项里的下拉列表选择 …