[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch import optim


class DNN(nn.Module):
    
    '''
    它是一个序列容器,是nn.Module的子类。 
    `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行
    相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。
    '''
    def __init__(self):
        
        super(DNN, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=500),
            nn.Sigmoid(),
            nn.Linear(in_features=500, out_features=10),
            nn.Sigmoid()
            )

    def forward(self, input):
        
        output = self.net(input)
        
        return output


def train():
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    model = DNN()
    criteon = torch.nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    batch_size= 5
    data = torch.rand((batch_size,28*28))
    epochs = 2
    target = torch.tensor([0,1,2,3,4])
    target = target.to(device)
    
    for epoch in range(epochs):
        
        yHat = model(data)
        loss = criteon(yHat, target)
        loss.backward()
        print("\n loss ",loss)
        
        optimizer.step()
        

if __name__ == "__main__":
    train()
    
    
    

 


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/258665.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【YOLOv8量化】普通CPU上加速推理可达100+FPS

NNCF介绍 OpenVINO2023版本衍生出了一个新支持工具包NNCF(Neural Network Compression Framework – 神经网络压缩框架),通过对OpenVINO IR格式模型的压缩与量化更好的提升模型在OpenVINO框架上部署的推理性能,github。 https://github.com/openvinoto…

银河麒麟v10 安装mysql 8.35

银河麒麟v10 安装mysql 8.35 1、下载Mysql安装包2、安装Mysql 8.352.1、安装依赖包2.2、安装Mysql2.3、安装后配置 1、下载Mysql安装包 访问官网下载链接 链接: https://dev.mysql.com/downloads/mysql/ 选择如下 点击下载按钮 下载安装包 2、安装Mysql 8.35 官方安装文档…

AWS Linux安装桌面并远程访问

文章目录 小结问题及解决参考 小结 在AWS Linux安装了桌面并进行远程访问。 问题及解决 需要使用过程桌面访问AWS Linux,这里在AWS服务器安装并使用Amazon Linux 2 MATE desktop。 检查OS版本: [ec2-userip-10-0-3-241 ~]$ grep PRETTY_NAME /etc/o…

算法设计与分析2023秋-头歌实验-实验七 动态规划

文章目录 第1关:数塔问题任务描述相关知识编程要求解题思路测试说明参考答案 第2关:最长公共子序列任务描述相关知识编程要求解题思路:测试说明参考答案 第3关:求序列-2 11 -4 13 -5 -2的最大子段和任务描述相关知识编程要求解题思…

可实现RSSD云硬盘120万IOPS的SPDK IO路径优化

一. 简介 用户对超高并发、超大规模计算等需求推动了存储硬件技术的不断发展,存储集群的性能越来越好,延时也越来越低,对整体IO路径的性能要求也越来越高。在云硬盘场景中,IO请求从生成到后端的存储集群再到返回之间的IO路径比较…

less基础介绍

什么是less? Less是一个C5S预处理器,Less文件后缀是,less。扩充了 CSS 语言,使CSS 具备一定的逻辑性、计算能力 注意事项: 浏览器不识别 Less 代码,目前阶段,网页要引入对应的 CSS 文件 VS Code 插件: Easy LESS,保存 …

JVM快速入门

JVM 字节码 字节码文件的组成 字节码由五个部分组成:基础信息 常量池 字段 方法 属性 基础信息: 魔数、字节码文件对应的版本号、访问标识(public final)、该类的父类索引、该类实现哪些接口的索引 魔数:文件无法…

鸿蒙 - arkTs:渲染(循环 - ForEach,判断 - if)

ForEach循环渲染: 参数: 要循环遍历的数组,Array类型遍历的回调方法,Function类型为每一项生成唯一标识符的方法,有默认生成方法,非必传 使用示例: interface Item {name: String,price: N…

ValueError: source code string cannot contain null bytes

解决:把存在这个问题的包,全部卸载重装 pip uninstall xxx pip install xxx

美国联邦机动车安全标准-FMVSS

FMVSS标准介绍: FMVSS是美国《联邦机动车安全标准》,由美国运输部下属的国家公路交通安全管理局(简称NHTSA)具体负责制定并实施。是美国联邦政府针对机动车制定的安全标准,旨在提高机动车的安全性能,减少交通事故中的人员伤亡。F…

路由跳转传递参数注意事项,查询字符串传参,params传参需要注意的地方,菜单内容的二级内容 vue3

路由跳转和传参(vue3)_vue3路由传参-CSDN博客 注意: import {useRouter} from "vue-router"const routeruseRouter()1.查询字符串传参,传一个对象,对象里面可以写path字段 router.push({path:/item,query:{id:1}} ) 通过当前路由…

(保姆级教程)一篇文章,搞定所有Linux命令,以及tar解压缩命令,wget、rpm等下载安装命令,Linux的目录结构,以及用户和用户组

文章目录 Linux命令1. Linux目录结构2. 基本命令(了解)3. 目录(文件夹)命令列出目录切换目录创建目录删除目录复制目录移动和重命名目录 4. 文件命令创建文件编辑文件编辑文件时的其他操作 查看文件移动/重命名文件复制文件删除文…

如何通过ETLCloud的API对接功能实现各种SaaS平台数据对接

前言 当前使用SaaS系统的企业越来越多,当我们需要对SaaS系统中产生的数据进行分析和对接时就需要与SaaS系统提供的API进行对接,因为SaaS一般是不会提供数据库表给企业,这时就应该使用ETL(Extract, Transform, Load)的…

Jmeter接口程序项目实战教程

1.什么是jmeter? JMeter是100%完全由Java语言编写的,免费的开源软件,是非常优秀的性能测试和接口测试工具,支持主流协议的测试 2.jmeter能做什么? JMeter是100%完全由Java语言编写的软件性能测试的GUI的测试工具&am…

车载蓝牙物联网解决方案

车载蓝牙物联网解决方案是一种基于蓝牙技术,结合物联网技术的智能车载系统。它利用蓝牙技术将智能手机、智能手表、智能车载设备等连接起来,实现设备之间的无缝通信和数据共享,为驾驶者提供更加便捷、安全和智能的驾驶体验。 车载蓝牙物联网解…

【3D数据读取】利用JAVA读取GLB(GLTF)文件数据

了解GLB和GLTF: GLB和GLTF是用于共享3D数据的标准化文件格式。GLB是GLTF的二进制格式,而GLTF基于JSON,一种基于文本的数据格式。 GLB文件: 由一个头部和一个二进制数据块组成。头部包含文件的元数据,例如文件版本、文件…

网络时代的新宠

当今社会,随着科技的不断进步和互联网的普及,手机已经成为了人们生活中不可或缺的一部分。它不仅仅是一个通信工具,更是娱乐、学习和获取信息的利器。而其中,手机无人直播更是近年来备受关注的热门话题。 直播,一种实…

hive 用户自定义函数udf,udaf,udtf

udf:一对一的关系 udtf:一对多的关系 udaf:多对一的关系 使用Java实现步骤 自定义编写UDF函数注意: 1.需要继承org.apache.hadoop.hive.ql.exec.UDF 2.需要实现evaluete函数 编写UDTF函数注意: 1.需要继承org.apache…

【MongoDB】--MongoDB的Sort排序问题

目录 一、问题背景描述1.1、问题背景1.2、问题分析 二、建立索引支持深度翻页查询2.1、调整sort排序的内存限制【不建议】2.2、创建索引2.3、拓展--组合索引什么时候失效 二、聚合查询解决深度翻页查询 一、问题背景描述 1.1、问题背景 现实系统页面翻页到20000页之后&#x…

MQTT直连接入

本文介如绍何使用MQTT协议,将设备直连到平台内置的MQTT服务。 操作步骤 创建产品 物联网->设备管理->选择产品,填写产品基础信息。 参数 对应设备侧参数 ID 产品唯一标识,若不填写,系统将自动生成唯一ID 设备类型 直…