详解反向传播(BP)算法

文章目录

      • what(是什么)
      • where(用在哪)
      • How(原理&&怎么用)
          • 原理以及推导过程
          • pytorch中的反向传播

what(是什么)

反向传播算法(Backpropagation)是一种用于训练人工神经网络的常见方法。它通过计算网络预测与实际结果之间的误差,然后反向传播这个误差来调整网络中每个权重的值,从而逐步优化网络的学习过程

在这里插入图片描述

where(用在哪)

绝大多数的神经网络都会使用反向传播算法进行网络权重以及阈值的更新,简单列举部分典型的使用场景如下

反向传播算法
前馈神经网络
多层感知机
卷积神经网络
循环神经网络
深度神经网络

How(原理&&怎么用)

原理以及推导过程

下面重点介绍反向传播算法的推导流程

在这里插入图片描述

假设有以上简单的神经网路模型,分为输入层、隐藏层、输出层。其中隐藏层包括4个神经元、输出层包括2个神经元。
假设输出层的两个神经元为 y 1 y_1 y1 y 2 y_2 y2,其激活阈值分别为 β \beta β γ \gamma γ,两个神经元的输入分别为 y 1 i n y_{1in} y1in y 2 i n y_{2in} y2in,输出分别为 y 1 ^ \hat{y_1} y1^ y 2 ^ \hat{y_2} y2^
假设隐藏层四个神经元为 h 1 h_1 h1 h 2 h_2 h2 h 3 h_3 h3 h 4 h_4 h4,其中 h 1 h_1 h1的激活阈值为 δ \delta δ,神经元 h 1 h_1 h1的输入值为 h i n h_{in} hin,输出值为 h o u t h_{out} hout
假设输入层两个神经元为 x 1 x_1 x1 x 2 x_2 x2,其中神经元 x 1 x_1 x1的输出为 x o u t x_{out} xout
假设神经元 x 1 x_1 x1到神经元 h 1 h_1 h1的连接权重为 W 11 W_{11} W11,神经元 h 1 h_1 h1到神经元 y 1 y_1 y1 y 2 y_2 y2的连接权重分别为 W 21 W_{21} W21 W 22 W_{22} W22
假设神经元的激活函数为sigmoid函数,sigmoid激活函数的表达式:
f ( x ) = 1 1 − e − x f(x)=\frac{1}{1-e^{-x}} f(x)=1ex1
该激活函数有一个非常好的性质:
f ′ ( x ) = f ( x ) ( 1 − f ( x ) ) f'(x)=f(x)(1-f(x)) f(x)=f(x)(1f(x))
下面,详细介绍连接权重 W W W以及激活阈值的更新过程。
首先,给出 W 21 W_{21} W21以及 β \beta β的更新公式,其中, W 21 W_{21} W21更新公式为:
W 21 = W 21 + η ∗ Δ W 21 W_{21}=W_{21}+\eta*\Delta W_{21} W21=W21+ηΔW21
同理, β \beta β更新公式为:
β = β + η ∗ Δ β \beta=\beta+\eta*\Delta \beta β=β+ηΔβ

在以上公式中,只有 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ未知,需要计算。而已知的是样本,也就是 ( x , y ) (x,y) (x,y),那么我们将通过样本数据来表达出上述 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ
根据反向传播算法, Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ分别为最终的误差对 W 21 W_{21} W21以及 β \beta β的偏导数。假设采用的损失函数为:
L o s s = 1 2 ( y 1 − y 1 ^ ) 2 + 1 2 ( y 2 − y 2 ^ ) 2 Loss=\frac{1}{2}(y_1-\hat{y_1})^2+\frac{1}{2}(y_2-\hat{y_2})^2 Loss=21(y1y1^)2+21(y2y2^)2
扩展到输出层有k个神经元的情况:
L o s s = 1 2 Σ 1 k ( y i − y i ^ ) 2 Loss=\frac{1}{2}\Sigma_1^k(y_i-\hat{y_i})^2 Loss=21Σ1k(yiyi^)2
而从输出端看,能得到以下表达式:
y 1 ^ = f ( y 1 i n − β ) = f ( W 21 h o u t − β ) \hat{y_1}=f(y_{1in}-\beta)=f(W_{21}h_{out}-\beta) y1^=f(y1inβ)=f(W21houtβ)
y 1 ^ \hat{y_1} y1^带入到损失函数中,也就是:
L o s s = 1 2 ( y 1 − f ( W 21 h o u t − β ) ) 2 + 1 2 ( y 2 − f ( W 22 h o u t − γ ) ) 2 Loss = \frac{1}{2}(y_1-f(W_{21}h_{out}-\beta))^2+\frac{1}{2}(y_2-f(W_{22}h_{out}-\gamma))^2 Loss=21(y1f(W21houtβ))2+21(y2f(W22houtγ))2
如此,便得出损失和 W 21 W_{21} W21之间的代数关系式,接下来只需要对该表达式求导即可得到 Δ W 21 \Delta W_{21} ΔW21以及 Δ β \Delta \beta Δβ

首先, ∂ L o s s ∂ W 21 \frac{\partial Loss}{\partial W_{21}} W21Loss的计算公式为:
∂ L o s s ∂ W 21 = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ h o u t = − [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] ∗ h o u t = − ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} \frac{\partial Loss}{\partial W_{21}} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*h_{out} \\ & =- [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))]*h_{out} \\ & = -(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21Loss=[y1f(W21houtβ)][f(W21houtβ)]hout=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]hout=(y1y1^)y1^(1y1^)hout
同样地, ∂ L o s s ∂ β \frac{\partial Loss}{\partial \beta} βLoss的计算公式为:
∂ L o s s ∂ β = [ y 1 − f ( W 21 h o u t − β ) ] ∗ [ − f ′ ( W 21 h o u t − β ) ] ∗ ( − 1 ) = [ y 1 − f ( W 21 h o u t − β ) ] ∗ f ( W 21 h o u t − β ) [ 1 − ( f ( W 21 h o u t − β ) ) ] = ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \frac{\partial Loss}{\partial \beta} & = [y_1-f(W_{21}h_{out}-\beta)]*[-f'(W_{21}h_{out}-\beta)]*(-1) \\ & = [y_1-f(W_{21}h_{out}-\beta)]*f(W_{21}h_{out}-\beta)[1-(f(W_{21}h_{out}-\beta))] \\ & = (y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} βLoss=[y1f(W21houtβ)][f(W21houtβ)](1)=[y1f(W21houtβ)]f(W21houtβ)[1(f(W21houtβ))]=(y1y1^)y1^(1y1^)
由于梯度下降法,需要沿着负梯度方向,所以, Δ W 21 = − ∂ L o s s ∂ W 21 \Delta W_{21}=-\frac{\partial Loss}{\partial W_{21}} ΔW21=W21Loss Δ β = − ∂ L o s s ∂ β \Delta \beta=-\frac{\partial Loss}{\partial \beta} Δβ=βLoss,从而得出 W 21 , β W_{21},\beta W21,β的更新公式为:
W 21 = W 21 + η ∗ Δ W 21 = W 21 − η ∗ ∂ L o s s ∂ W 21 = W 21 + η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) ∗ h o u t \begin{aligned} W_{21} &= W_{21} + \eta*\Delta W_{21} \\ & = W_{21}-\eta * \frac{\partial Loss}{\partial W_{21}} \\ & =W_{21}+\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1})*h_{out} \end{aligned} W21=W21+ηΔW21=W21ηW21Loss=W21+η(y1y1^)y1^(1y1^)hout

β = β + η ∗ Δ β = β − η ∗ ∂ L o s s ∂ β = β − η ∗ ( y 1 − y 1 ^ ) ∗ y 1 ^ ∗ ( 1 − y 1 ^ ) \begin{aligned} \beta & = \beta+\eta*\Delta \beta \\ & = \beta-\eta* \frac{\partial Loss}{\partial \beta} \\ & = \beta-\eta *(y_1-\hat{y_1})*\hat{y_1}*(1-\hat{y_1}) \end{aligned} β=β+ηΔβ=βηβLoss=βη(y1y1^)y1^(1y1^)

使用同样的方式,可以对 W 11 , δ W_{11},\delta W11,δ的梯度公式进行计算和更新。

pytorch中的反向传播

下面举例说明在pytorch中,如何使用反向传播算法来更新权重以及阈值。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义一个复杂的神经网络
class ComplexNet(nn.Module):
    def __init__(self):
        super(ComplexNet, self).__init__()
        self.fc1 = nn.Linear(10, 50)  # 输入大小为10,输出大小为50
        self.fc2 = nn.Linear(50, 20)  # 输入大小为50,输出大小为20
        self.fc3 = nn.Linear(20, 1)   # 输入大小为20,输出大小为1

    def forward(self, x):
        x = F.relu(self.fc1(x))  # 使用ReLU作为激活函数
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建网络实例
model = ComplexNet()

# 定义损失函数
criterion = nn.MSELoss()

# 随机生成一些输入和目标输出数据
input_data = torch.randn((32, 10))  # 32个样本,每个样本特征数为10
target_output = torch.randn((32, 1))  # 对应的32个目标输出

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
model.train()	# 设置模型为训练模式
epochs = 1000
for epoch in range(epochs):
    # 梯度清零
    optimizer.zero_grad()

    # 前向传播
    output = model(input_data)

    # 计算损失
    loss = criterion(output, target_output)

    # 反向传播
    loss.backward()

    # 更新模型参数
    optimizer.step()

    # 每隔一段时间输出一下损失值
    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 打印模型结构
print(model)

pythrch中,输入在流经每一个神经元时,会构建一个动态计算图(与tensorflow不同,tensorflow为静态计算图),记录了每个神经元的输入输出信息。在反向传播时, loss.backward()会根据已知的样本数据以及神经元的输入输出信息,计算连接权重以及阈值的梯度,然后optimizer.step()来实现对权重和阈值的更新。需要注意的是,在每一个mini-batch开始前,需要使用optimizer.zero_grad()对梯度置零。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/765479.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发Ability Kit(程序访问控制):【安全控件概述】

安全控件概述 安全控件是系统提供的一组系统实现的ArkUI组件,应用集成这类组件就可以实现在用户点击后自动授权,而无需弹窗授权。它们可以作为一种“特殊的按钮”融入应用页面,实现用户点击即许可的设计思路。 相较于动态申请权限的方式&am…

【聊聊原子性,中断,以及nodejs中的具体示例】

什么是原子性 从一个例子说起, x ,读和写 , 如图假设多线程,线程1和线程2同时操作变量x,进行x的操作,那么由于写的过程中,都会先读一份x数据到cpu的寄存器中,所以这个时候cpu1 和 c…

【ONLYOFFICE】| 桌面编辑器从0-1使用初体验

目录 一. 🦁 写在前面二. 🦁 在线使用感受2.1 创建 ONLYOFFICE 账号2.2 编辑pdf文档2.3 pdf直接创建表格 三. 🦁 写在最后 一. 🦁 写在前面 所谓桌面编辑器就是一种用于编辑文本、图像、视频等多种自媒体的软件工具,具…

OBS 免费的录屏软件

一、下载 obs 【OBS】OBS Studio 的安装、参数设置和录屏、摄像头使用教程-CSDN博客 二、使用 obs & 输出无黑屏 【OBS任意指定区域录屏的方法-哔哩哔哩】 https://b23.tv/aM0hj8A OBS任意指定区域录屏的方法_哔哩哔哩_bilibili 步骤: 1)获取区域…

Qt源码分析:窗体绘制与响应

作为一套开源跨平台的UI代码库,窗体绘制与响应自然是最为基本的功能。在前面的博文中,已就Qt中的元对象系统(反射机制)、事件循环等基础内容进行了分析,并捎带阐述了窗体响应相关的内容。因此,本文着重分析Qt中窗体绘制相关的内容…

Vue3快速上手--3小时掌握

1. Vue3简介 2020年9月18日,Vue.js发布版3.0版本,代号:One Piece(n经历了:4800次提交、40个RFC、600次PR、300贡献者官方发版地址:Release v3.0.0 One Piece vuejs/core截止2023年10月,最新的…

阿里Nacos下载、安装(保姆篇)

文章目录 Nacos下载版本选择Nacos安装Windows常见问题解决 更多相关内容可查看 Nacos下载 Nacos官方下载地址:https://github.com/alibaba/nacos/releases 码云拉取(如果国外较慢或者拉取超时可以试一下国内地址) //国外 git clone https:…

数组-长度最小的子数组

M长度最小的子数组(leetcode209) /*** param {number} target* param {number[]} nums* return {number}*/ var minSubArrayLen function(target, nums) {const n nums.length;let ans n 1;let sum 0; // 子数组元素和let left 0; // 子数组…

美团实习—后端开发凉经

面试经历分享 日期: 4月22日时长: 50分钟 意外之喜 没想到在面试过程中,我再次被选中进行下一轮,这确实让我感到有些意外和欣喜。这次面试经历对我而言,不仅是一次技能的检验,更是一次知…

MySQL:设计数据库与操作

设计数据库 1. 数据建模1.1 概念模型1.2 逻辑模型1.3 实体模型主键外键外键约束 2. 标准化2.1 第一范式2.2 链接表2.3 第二范式2.4 第三范式 3. 数据库模型修改3.1 模型的正向工程3.2 同步数据库模型3.3 模型的逆向工程3.4 实际应用建议 4. 数据库实体模型4.1 创建和删除数据库…

10.8K star!史上最强Web应用防火墙雷池WAF

长亭雷池SafeLine是长亭科技耗时近 10 年倾情打造的WAF(Web Application Firewall), 一款敢打出口号 “不让黑客越雷池一步” 的 WAF,愿称之为史上最强的一款Web应用防火墙,足够简单、足够好用、足够强的免费且开源的 WAF,基于业…

leetcode-20-回溯-切割、子集

一、[131]分割回文串 给定一个字符串 s,将 s 分割成一些子串,使每个子串都是回文串。 返回 s 所有可能的分割方案。 示例: 输入: "aab" 输出: [ ["aa","b"], ["a","a","b"] ] 分析&…

JAVA连接FastGPT实现流式请求SSE效果

FastGPT 是一个基于 LLM 大语言模型的知识库问答系统,提供开箱即用的数据处理、模型调用等能力。同时可以通过 Flow 可视化进行工作流编排,从而实现复杂的问答场景! 一、先看效果 真正实流式请求,SSE效果,SSE解释&am…

一切为了安全丨2024中国应急(消防)品牌巡展武汉站成功召开!

消防品牌巡展武汉站 6月28日,由中国安全产业协会指导,中国安全产业协会应急创新分会、应急救援产业网联合主办,湖北消防协会协办的“一切为了安全”2024年中国应急(消防)品牌巡展-武汉站成功举办。该巡展旨在展示中国应急(消防&am…

Python基础002

Python数据类型 1、字符串&#xff08;str&#xff09; str3 """I miss you so much""" print("str3 ", str3,type(str3)) str3 I miss you so much <class str>2、整数&#xff08;int&#xff09; str1 55 print(&quo…

【面试题】TLS和SSL协议的区别

TLS&#xff08;Transport Layer Security&#xff09;和SSL&#xff08;Secure Sockets Layer&#xff09;协议都是用于在网络上建立安全通信连接的协议&#xff0c;但它们在多个方面存在区别。以下是TLS和SSL协议之间区别的详细分析&#xff1a; 1. 发展历程与标准化 SSL&a…

如何找BMS算法、BMS软件的实习

之前一直忙&#xff0c;好久没有更新了&#xff0c;今天就来写一篇文章来介绍如何找BMS方向的实习&#xff0c;以及需要具备哪些条件&#xff0c;我的实习经历都是在读研阶段找的&#xff0c;读研期间两段的实习经历再加上最高影响因子9.4分的论文&#xff0c;我的秋招可以说是…

分子AI预测赛Task2笔记

下面所述比较官方的内容都来自官方文档 ‍‌⁠‌‍​​​‌​​⁠​​​​​&#xfeff;​​​&#xfeff;‍‬​​‍⁠‍‍​​‬​&#xfeff;‌​​​‌‍‬​​​​​​‍‌Task2&#xff1a;赛题深入解析 - 飞书云文档 (feishu.cn) 赛题背景 强调了人工智能在科研领域&…

探囊取物之多形式注册页面(基于BootStrap4)

基于BootStrap4的注册页面&#xff0c;支持手机验证码注册、账号密码注册 低配置云服务器&#xff0c;首次加载速度较慢&#xff0c;请耐心等候&#xff1b;演练页面可点击查看源码 预览页面&#xff1a;http://www.daelui.com/#/tigerlair/saas/preview/ly4gax38ub9j 演练页…

晚上睡觉要不要关路由器?一语中的

前言 前几天小白去了一个朋友家&#xff0c;有朋友说&#xff1a;路由器不关机的话会影响睡眠吗&#xff1f; 这个影响睡眠嘛&#xff0c;确实是会的。毕竟一时冲浪一时爽&#xff0c;一直冲浪一直爽……刷剧刷抖音刷到根本停不下来&#xff0c;肯定影响睡眠。 所以晚上睡觉要…