神经网络基础--什么是正向传播??什么是方向传播??

前言

  • 本专栏更新神经网络的一些基础知识;
  • 这个是本人初学神经网络做的笔记,仅仅堆正向传播、方向传播就行了了一个讲解,更加系统的讲解,本人后面会更新《李沐动手学习深度学习》,会更有详细讲解;
  • 案例代码基于pytorch;
  • 欢迎收藏 + 关注, 本人将会持续更新。

文章目录

  • 正向传播与反向传播
    • 梯度下降法
      • 简介
      • 不同梯度下降法区别
    • 前向传播
    • 反向传播算法
      • 简介
      • 案例介绍原理
      • 代码展示

正向传播与反向传播

梯度下降法

简介

表达式

w i j n e w = w i j o l d − η ∂ E ∂ w i j w_{ij}^{new}= w_{ij}^{old} - \eta \frac{\partial E}{\partial w_{ij}} wijnew=wijoldηwijE

其中 n \text{n} n 是学习率,控制梯度收敛的快慢。

深度学习几个基础概念

  1. Eopch:使用数据对模型进行完整训练,一次
  2. Batch:使用训练集中小部分样本对模型权重进行方向传播更新
  3. Iteration:使用一个 Batch 数据对模型进行一次参数更新

不同梯度下降法区别

梯度下降方式Training Set SizeBatch SizeNumber of Batches
BGD(批量梯度下降)NN1
SGD(随机梯度下降)N1N
Mini-Batch(小批量梯度下降)NBN / B + 1
  • 批量梯度下降法是最原始的形式,它是指在每一次迭代时使用所有样本来进行梯度的更新;
  • 随机梯度下降法不同于批量梯度下降,随机梯度下降是在每次迭代时使用一个样本来对参数进行更新;
  • 小批量梯度下降相当于是前两个的总和。
  • 具体优缺点:后面更新**李沐老师《动手学习深度学习》**会有更详细解释。

前向传播

前向传播就是输入x,在神经网络中一直向前计算,一直到输出层为止,图像如下:

在这里插入图片描述

在网络的训练过程中经过前向传播后得到的最终结果跟训练样本的真实值总是存在一定误差,这个误差便是损失函数。想要减小这个误差,就用损失函数Loss,从后往前,依次求各个参数的偏导,这就是反向传播.

反向传播算法

简介

BP 算法也叫做误差反向传播算法,它用于求解模型的参数梯度,从而使用梯度下降法来更新网络参数。它的基本工作流程如下:

  1. 通过正向传播得到误差,正向传播指的是数据从输入–> 隐藏层–>输出层,经过层层计算得到预测值,并利用损失函数得到预测值和真实值之前的误差。
  2. 通过反向传播把误差传递给模型的参数,从而调整神经网络参数,缩小预测值和真实值之间的误差。
  3. 反向传播算法是利用链式法则进行梯度求解,然后进行参数更新

案例介绍原理

  • 网络结构:
    • 输入层:两个神经元
    • 隐藏层:一共两层,每一层两个神经元
    • 输出层:输出两个值
  • 输入:
    • i1:0.05,i2:0.10
  • 目标值:
    • 0.01,0.99
  • 初始化权重:
    • w1: 0.15
    • w2: 0.20
    • w3: 0.25
    • w4:0.30
    • w5:0.40
    • w6:0.45
    • w7:0.50
    • w8:0.55

在这里插入图片描述

  1. 由下向上看,最下层绿色的两个圆代表两个输入值
  2. 右侧的8个数字,最下面4个表示 w1、w2、w3、w4 的参数初始值,最上面的4个数字表示 w5、w6、w7、w8 的参数初始值
  3. b1 值为 0.35,b2 值为 0.60
  4. 预测结果分别为: 0.7514、0.7729

接下来就是梯度更新

方向传播中梯度跟新流程

w5、w7为例:

在这里插入图片描述

计算出偏导后,运用梯度下降法进行更新,这里学习率为0.5:

在这里插入图片描述

接下来跟新w1:

在这里插入图片描述

梯度更新:

在这里插入图片描述

代码展示

该代码就是将上面的神经网络代码实现,牢记对神经网络理解很有用

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # 定义神经网络结构
        self.linear1 = nn.Linear(2, 2)
        self.linear2 = nn.Linear(2, 2)
        
        # 神经网络参数初始化,设置权重和偏置
        self.linear1.weight.data = torch.tensor([[0.15, 0.20], [0.25, 0.30]])  # 权重
        self.linear2.weight.data = torch.tensor([[0.40, 0.45], [0.50, 0.55]])
        self.linear1.bias.data = torch.tensor([0.35, 0.35])   # 偏置
        self.linear2.bias.data = torch.tensor([0.60, 0.60])
        
    def forward(self, x):
        
        x = self.linear1(x)     # 经过第一次线性变换
        x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换
        x = self.linear2(x)     # 经过第二层线性变换
        x = torch.sigmoid(x)    # 经过激活函数,进行非线性变换
        
        return x
        
if __name__ == '__main__':
    
    # 输入变量和目标变量
    inputs = torch.tensor([0.05, 0.10])
    target = torch.tensor([0.01, 0.99])
    
    # 创建神经网络
    net = Net()
    # 训练
    outputs = net(inputs)
    
    # 计算误差,梯度下降法,*********  公式实现  ***********
    loss = torch.sum((target - outputs) ** 2) / 2
    
    # 计算梯度下降
    optimizer = optim.SGD(net.parameters(), lr=0.5)
    
    # 清除梯度
    optimizer.zero_grad()
    
    # 反向传播
    loss.backward()
    
    # 打印 w5、w7、w1 的梯度值
    print(net.linear1.weight.grad.data)
    print(net.linear2.weight.grad.data)
    
    # 更新权重,梯度下降更新
    optimizer.step()
    
    # 打印网络参数
    print(net.state_dict())

输出:

tensor([[0.0004, 0.0009],
        [0.0005, 0.0010]])
tensor([[ 0.0822,  0.0827],
        [-0.0226, -0.0227]])
OrderedDict([('linear1.weight', tensor([[0.1498, 0.1996],
        [0.2498, 0.2995]])), ('linear1.bias', tensor([0.3456, 0.3450])), ('linear2.weight', tensor([[0.3589, 0.4087],
        [0.5113, 0.5614]])), ('linear2.bias', tensor([0.5308, 0.6190]))])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/911083.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第三十七天 | 完全背包 518.零钱兑换 Ⅱ 377.组合总和Ⅳ 70.爬楼梯(进阶版)

完全背包: 文章链接 题目链接:卡码网 52.携带研究材料 与01背包的区别在于物品数量无限,因此同一种物品可以取多次。 递推式如下: 二维:dp[i][j] max(dp[i - 1][j], dp[i][j - weights[i]] value[i]),因…

C语言心型代码解析

方法一 心型极坐标方程 爱心代码你真的理解吗 笛卡尔的心型公式&#xff1a; for (y 1.5; y > -1.5; y - 0.1) for (x -1.5; x < 1.5; x 0.05) 代码里面用了二个for循环&#xff0c;第一个代表y轴&#xff0c;第二个代表x轴 二个增加的单位不同&#xff0c;能使得…

C语言网络编程 -- TCP/iP协议

一、Socket简介 1.1 什么是socket socket通常也称作"套接字"&#xff0c;⽤于描述IP地址和端⼝&#xff0c;是⼀个通信链的句柄&#xff0c;应⽤ 程序通常通过"套接字"向⽹络发出请求或者应答⽹络请求。⽹络通信就是两个进程 间的通信&#xff0c;这两个进…

字符串接龙 /单词接龙 (BFs C#

卡码网 110和 力扣127 和LCq 108题都是一个解法 这两道题乍一看在结果处可能不一样 力扣要求 字符串里边必须包含对应的最后一个字符 而110不需要最后一个字符 但是在实验逻辑上是一致的 只是110需要把如果在set中找不到最后一个字符就直接返回0的逻辑删去 就可以了 这就是…

Transformer和BERT的区别

Transformer和BERT的区别比较表&#xff1a; 两者的位置编码&#xff1a; 为什么要对位置进行编码&#xff1f; Attention提取特征的时候&#xff0c;可以获取全局每个词对之间的关系&#xff0c;但是并没有显式保留时序信息&#xff0c;或者说位置信息。就算打乱序列中token…

python操作MySQL以及SQL综合案例

1.基础使用 学习目标&#xff1a;掌握python执行SQL语句操作MySQL数据库软件 打开cmd下载安装 安装成功 connection就是一个类&#xff0c;conn类对象。 因为位置不知道&#xff0c;所以使用关键字传参。 表明我们可以正常连接到MySQL 演示、执行非查询性质的SQL语句 pytho…

【报告PDF附下载】2024人工智能大模型技术财务应用蓝皮书

《人工智能大模型技术财务应用蓝皮书》 是一本探讨AI大模型技术在财务管理领域应用的权威指南。书中不仅概述了人工智能大模型技术的发展历程、典型特征和未来趋势&#xff0c;还详细介绍了它的体系架构和在财务领域的应用情况。 书中通过家用电器制造、银行、汽车企业、基础设…

快速上手vue3+js+Node.js

安装Navicat Premium Navicat Premium 创建一个空的文件夹&#xff08;用于配置node&#xff09; 生成pakeage.json文件 npm init -y 操作mysql npm i mysql2.18.1 安装express搭建web服务器 npm i express4.17.1安装cors解决跨域问题 npm i cors2.8.5创建app.js con…

【Python爬虫实战】DrissionPage 与 ChromiumPage:高效网页自动化与数据抓取的双利器

&#x1f308;个人主页&#xff1a;易辰君-CSDN博客 &#x1f525; 系列专栏&#xff1a;https://blog.csdn.net/2401_86688088/category_12797772.html ​ 目录 前言 一、DrissionPage简介 &#xff08;一&#xff09;特点 &#xff08;二&#xff09;安装 &#xff08;三…

【JAVA】java 企业微信信息推送

前言 JAVA中 将信息 推送到企业微信 // 企微消息推送messageprivate String getMessage(String name, String problemType, String pushResults, Long orderId,java.util.Date submitTime, java.util.Date payTime) {String message "对接方&#xff1a;<font color\…

前端md5加密

npm下载 npm install --save ts-md5页面引入 import { Md5 } from ts-md5使用 const md5PwdMd5.hashStr("123456")md5Pwd&#xff08;加密后的数据&#xff09; .toUpperCase()方法转大写

DDRSYS,不同频点的时序参数配置说明,DBI/DM功能说明

文章目录 不同频点的时序参数配置说明LPDDR4 时序参数DFI 参数对应配置DDR3/4DBI功能说明&#xff0c;MC控制DBI情况 不同频点的时序参数配置说明 LPDDR4 时序参数 LP4的时序参数从JEDEC颗粒文档可以检索到读写的时序参数如下&#xff1a; 此图主要关注不同频点对应的RL和WL…

如何自学机器学习?

自学机器学习可以按照以下步骤进行&#xff1a; 一、基础知识准备 数学基础&#xff1a; 高等数学&#xff1a;学习微积分&#xff08;包括导数、微分、积分等&#xff09;、极限、级数等基本概念。这些知识是后续学习算法和优化方法的基础。 线性代数&#xff1a;掌握矩阵…

工程巡查应该怎么做?如何利用巡查管理软件?

工程行业&#xff0c;无论是建设单位&#xff0c;监理单位&#xff0c;还是施工单位&#xff0c;工程巡查几乎是每日必做的工作。然而&#xff0c;巡查过程中&#xff0c;传统的做法通常依赖手动记录、拍照上传、在微信群中进行汇报。这种方式需要建大量的微信群&#xff0c;不…

Scala入门基础(16)scala的包

Scala的包定义包定义包对象Scala的包的导入导入重命名 一.Scala的包 package&#xff08;包&#xff1a;一个容器。可以把类&#xff0c;对象&#xff0c;包&#xff0c;装入。 好处&#xff1a; 区分同名的类&#xff1b;类很多时&#xff0c;更好地管理类&#xff1b;控制…

协程6 --- HOOK

文章目录 HOOK 概述链接运行时动态链接 linux上的常见HOOK方式修改函数指针用户态动态库拦截getpidmalloc 第一版malloc 第二版malloc/free通过指针获取到空间大小malloc 第三版strncmp 内核态系统调用拦截堆栈式文件系统 协程的HOOK HOOK 概述 原理&#xff1a;修改符号指向 …

MySQL中,GROUP BY 分组函数

文章目录 示例查询&#xff1a;按性别分组统计每组信息示例查询&#xff1a;按性别分组显示详细信息示例查询&#xff1a;按性别分组并计算平均年龄,如果你还想统计每个性别的平均年龄&#xff0c;可以结合AVG()函数&#xff1a;说明 示例查询&#xff1a;按性别分组统计每组信…

免费数据集网站

1、DataSearch https://datasetsearch.research.google.comhttp://DataSearch 2、FindData findata-科学数据搜索引擎https://www.findata.cn/ 3、Kaggle Kaggle: Your Machine Learning and Data Science CommunityKaggle is the world’s largest data science community …

十二:java web(4)-- Spring核心基础

目录 创建项目 Spring 核心基础 Spring 容器 Spring 容器的作用 Spring 容器的工作流程 Bean Bean 的生命周期 IOC&#xff08;控制反转&#xff09;与依赖注入&#xff08;DI&#xff09; 控制反转的概念 依赖注入的几种方式&#xff08;构造器注入、Setter 注入、接…

MybatisPlus入门(八)MybatisPlus-DQL编程控制

一、字段映射与表名映射 数据库表和实体类名称一样自动关联&#xff0c;数据库表和实体类有部分情况不一样。 问题一&#xff1a;表名与编码开发设计不同步&#xff0c;表名和实体类名称不一致。 解决办法&#xff1a; 在模型类上方&#xff0c;使用TableName注解&#xf…