神经网络:卷积神经网络中的BatchNorm

在这里插入图片描述

一、BN介绍

1.原理

在机器学习中让输入的数据之间相关性越少越好,最好输入的每个样本都是均值为0方差为1。在输入神经网络之前可以对数据进行处理让数据消除共线性,但是这样的话输入层的激活层看到的是一个分布良好的数据,但是较深的激活层看到的的分布就没那么完美了,分布将变化的很严重。这样会使得训练神经网络变得更加困难。所以添加BatchNorm层,在训练的时候BN层使用batch来估计数据的均值和方差,然后用均值和方差来标准化这个batch的数据,并且随着不同的batch经过网络,均值和方差都在做累计平均。在测试的时候就直接作为标准化的依据。

这样的方法也有可能导致降低神经网络的表示能力,因为某些层的全局最优的特征可能不是均值为0或者方差为1的。所以BN层也是能够进行学习每个特征维度的缩放gamma和平移beta的来避免这样的情况。

2.BN层前向传播

def batchnorm_forward(x, gamma, beta, bn_param):
    """先进行标准化再进行平移缩放
    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Input:
    - x: (N, D) 输入的数据
    - gamma: (D,) 每个特征维度数据的缩放
    - beta: (D,) 每个特征维度数据的偏移
    - bn_param: 字典,有如下键值:
       - mode: 'train'/'test' 必须指定
       - eps: 一个常量为了维持数值稳定,保证不会除0
       - momentum: 动量
       - running_mean: (D,) 积累的均值
       - running_var: (D,) 积累的方差

    Returns:
    - out: (N,D)
    - cache: 反向传播时需要的数据
    """
    mode = bn_param['mode']
    eps = bn_param.get('eps', 1e-5)
    momentum = bn_param.get('momentum', 0.9)

    N, D = x.shape
    running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == 'train':
        sample_mean = np.mean(x, axis=0)
        sample_var = np.var(x, axis=0)
        # 先标准化
        x_hat = (x - sample_mean)/(np.sqrt(sample_var + eps))
        # 再做缩放偏移
        out = gamma * x_hat + beta
        cache = (gamma, x, sample_mean, sample_var, eps, x_hat)
        running_mean = momentum * running_mean + (1-momuntum)*sample_mean
        running_var = momentum * running_var + (1-momentum)*sample_var
    elif mode == 'test':
        # 先标准化
        #x_hat = (x - running_mean)/(np.sqrt(running_var+eps))
        # 再做缩放偏移
        #out = gamma * x_hat + beta
        # 或者是下面的骚写法
        scale = gamma/(np.sqrt(running_var + eps))
        out = x*scale + (beta - running_mean*scale)
    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
    
    bn_param['running_mean'] = running_mean
    bn_param['running_var'] = running_var

    return out, cache

3.BN层反向传播

def batchnorm_barckward(out, cache):
    """反向传播的简单写法,易于理解
    Inputs:
    - dout: (N,D) dloss/dout
    - cache: (gamma, x, sample_mean, sample_var, eps, x_hat)

    Returns:
    - dx: (N,D)
    - dgamma: (D,) 每个维度的缩放和平移参数不同
    - dbeta: (D,)
    """
    dx, dgamma, dbeta = None, None, None
    # unpack cache
    gamma, x, u_b, sigma_squared_b, eps, x_hat = cache
    N = x.shape[0]

    dx_1 = gamma * dout # dloss/dx_hat = dloss/dout * gamma (N, D)
    dx_2_b = np.sum((x - u_b) * dx_1, axis=0)
    dx_2_a = ((sigma_squared_b + eps)**-0.5)*dx_1
    dx_3_b = (-0.5) * ((sigma_squared_b + eps)**-1.5)*dx_2_b
    dx_4_b = dx_3_b * 1
    dx_5_b = np.ones_like(x)/N * dx_4_b
    dx_6_b = 2*(x-u_b)*dx_5_b
    dx_7_a = dx_6_b*1 + dx_2_a*1
    dx_7_b = dx_6_b*1 * dx_2_a*1
    dx_8_b = -1*np.sum(dx_7_b, axis=0)
    dx_9_b = np.ones_like(x)/N * dx_8_b
    dx_10 = dx_9_b + dx_7_a

    dgamma = np.sum(x_hat * dout, axis=0)
    dbeta = np.sum(dout, axis=0)
    dx = dx_10

    return dx, dgamma, dbeta

下面是直接使用公式来计算:

def batchnorm_backward_alt(dout, cache):
    dx, dgamma, dbeta = None, None, None
    # unpack cache
    gamma, x, u_b, sigma_squared_b, eps, x_hat = cache
    N = x.shape[0]
    dx_hat = dout * gamma
    dvar = np.sum(dx_hat* (x - sample_mean) * -0.5 * np.power(sample_var + eps, -1.5), axis = 0)

    dmean = np.sum(dx_hat * -1 / np.sqrt(sample_var +eps), axis = 0) + dvar * np.mean(-2 * (x - sample_mean), axis =0)

    dx = 1 / np.sqrt(sample_var + eps) * dx_hat + dvar * 2.0 / N * (x-sample_mean) + 1.0 / N * dmean

    dgamma = np.sum(x_hat * dout, axis = 0)
    dbeta = np.sum(dout , axis = 0)

    return dx, dgamma, dbeta

4.BN有什么作用

  1. 对于不好的权重初始化有更高的鲁棒性,仍然能得到较好的效果。
  2. 能更好的避免过拟合。
  3. 解决梯度消失/爆炸问题,BN防止了前向传播的时候数值过大或者过小,这样就能让反向传播时梯度处于一个较好的区间内。

二、卷积神经网络中的BN

1.前向传播

def spatial_batchnorm_forward(x, gamma, beta, bn_param):
    """利用普通神经网络的BN来实现卷积神经网络的BN
    Inputs:
    - x: (N, C, H, W)
    - gamma: (C,)缩放系数
    - beta: (C,)平移系数
    - bn_param: 包含如下键的字典
       - mode: 'train'/'test'必须的键
       - eps: 数值稳定需要的一个较小的值
       - momentum: 一个常量,用来处理running mean和var的。如果momentum=0 那么之前不利用之前的均值和方差。momentum=1表示不利用现在的均值和方差,一般设置momentum=0.9
       - running_mean: (C,)
       - running_var: (C,)

    Returns:
    - out: (N, C, H, W)
    - cache: 反向传播需要的数据,这里直接使用了普通神经网络的cache
    """
    N, C, H, W = x.shape
    # transpose之后(N, W, H, C) channel在这里就可以看成是特征
    temp_out, cache = batchnorm_forward(x.transpose(0, 3, 2, 1).reshape((N*H*W, C)), gamma, beta, bn_param)
    # 再恢复shape
    out = temp_output.reshape(N, W, H, C).transpose(0, 3, 2, 1)
    return out, cache

2.反向传播

def spatial_batchnorm_backward(dout, cache):
    """利用普通神经网络的BN反向传播实现卷积神经网络中的BN反向传播
    Inputs:
    - dout: (N, C, H, W) 反向传播回来的导数
    - cache: 前向传播时的中间数据

    Returns:
    - dx: (N, C, H, W)
    - dgamma: (C,) 缩放系数的导数
    - dbeta: (C,) 偏移系数的导数
    """
    dx, dgamma, dbeta = None, None, None
    N, C, H, W = dout.shape
    # 利用普通神经网络的BN进行计算 (N*H*W, C)channel看成是特征维度
    dx_temp, dgamma, dbeta = batchnorm_backward_alt(dout.transpose(0, 3, 2, 1).reshape((N*H*W, C)), cache)
    # 将shape恢复
    dx = dx_temp.reshape(N, W, H, C).transpose(0, 3, 2, 1)
    return dx, dgamma, dbeta

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/386636.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java 集合

一、集合的框架体系(重要,背!!!) 1.Collection(单列集合) 2.Map(双列集合) 二、Collection接口 1.特点 使用了Collection接口的子类 可以存放多个元素&am…

C语言学习day12:for循环

前面学了dowhile循环&#xff0c;今天我们来学习经常用到的for循环&#xff1a; for循环&#xff1a; 例子&#xff1a; int main() {//int i;for (int i 0; i < 10;i) {printf("%d\n",i);};system("pause");return EXIT_SUCCESS; } 解释&#xff…

高中数学:不等式

一、性质 1、同向可加性 2、同向同正可乘 3、正数乘方开方&#xff08;n∈Z&#xff0c;n≥2&#xff09; 常见题型 1、比较大小 分式比较大小&#xff0c;先去分母作差法比较大小带根号的无理数比较大小&#xff0c;直接两边开方因式分解&#xff08;较难&#xff09; 2、…

PhP+vue企业原材料采购系统_cxg0o

伴随着我国社会的发展&#xff0c;人民生活质量日益提高。互联网逐步进入千家万户&#xff0c;改变传统的管理方式&#xff0c;原材料采购系统以互联网为基础&#xff0c;利用php技术&#xff0c;结合vue框架和MySQL数据库开发设计一套原材料采购系统&#xff0c;提高工作效率的…

Vue项目创建和nodejs使用

Vue项目创建和nodejs使用 一、环境准备1.1.安装 node.js【下载历史版本node-v14.21.3-x64】1.2.安装1.3.检查是否安装成功&#xff1a;1.4.在Node下新建两个文件夹 node_global和node_cache并设置权限1.5.配置npm在安装全局模块时的路径和缓存cache的路径1.6.配置系统变量&…

L2-015 互评成绩

一、题目 二、解题思路 去掉一个最高分和一个最低分&#xff1a;在输入的时候找出每个同学的最大值和最小值&#xff0c;index1[n],index2[n] 两个数组分别记录每个同学的最大值和最小值对应的下标。注意可能会有多个最大值或有多个最小值&#xff0c;也可能最大值和最小值相同…

C# EventHandler<T> 示例

新建一个form程序&#xff0c;在调试窗口输出执行过程&#xff1b; 为了使用Debug.WriteLine&#xff0c;添加 using System.Diagnostics; using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using S…

探索IDE的世界:什么是IDE?以及适合新手的IDE推荐

引言 在编程的世界里&#xff0c;集成开发环境&#xff08;IDE&#xff09;是我们日常工作的重要工具。无论是初学者还是经验丰富的开发者&#xff0c;一个好的IDE都能极大地提高我们的编程效率。那么&#xff0c;什么是IDE呢&#xff1f;对于新手来说&#xff0c;又应该选择哪…

HCIA-HarmonyOS设备开发认证V2.0-轻量系统内核基础-事件event

目录 一、事件基本概念二、事件运行机制三、事件开发流程四、事件使用说明五、事件接口坚持就有收获 一、事件基本概念 事件是一种实现任务间通信的机制&#xff0c;可用于实现任务间的同步&#xff0c;但事件通信只能是事件类型的通信&#xff0c;无数据传输。一个任务可以等…

【学网攻】 第(27)节 -- HSRP(热备份路由器协议)

系列文章目录 目录 系列文章目录 文章目录 前言 一、HSRP(热备份路由器协议)是什么&#xff1f; 二、实验 1.引入 实验目标 实验背景 技术原理 实验步骤 实验设备 实验拓扑图 实验配置 实验验证 文章目录 【学网攻】 第(1)节 -- 认识网络【学网攻】 第(2)节 -- 交…

ICLR 2024 | Harvard FairSeg:第一个研究分割算法公平性的大型医疗分割数据集

近年来&#xff0c;人工智能模型的公平性问题受到了越来越多的关注&#xff0c;尤其是在医学领域&#xff0c;因为医学模型的公平性对人们的健康和生命至关重要。高质量的医学公平性数据集对促进公平学习研究非常必要。现有的医学公平性数据集都是针对分类任务的&#xff0c;而…

掌握C语言文件操作:从入门到精通的完整指南!

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C语言学习 贝蒂的主页&#xff1a;Betty‘s blog 1. 什么是文件 文件其实是指一组相关数据的有序集合。这个数据集有一个名称&a…

高校危化试剂管理:Java与SpringBoot的革新

✍✍计算机编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java实战 |…

Java学习第十二节之可变参数和递归

可变参数 package method;import javax.swing.*;public class Demo04 {public static void main(String[] args) {//调用可变参数的方法printMax(34,3,3,2,56.5);printMax(new double[]{1, 2, 3});}public static void printMax(double...numbers) {if (numbers.length 0) {Sy…

腾讯云4核8G服务器能支持多少人访问?

腾讯云4核8G服务器支持多少人在线访问&#xff1f;支持25人同时访问。实际上程序效率不同支持人数在线人数不同&#xff0c;公网带宽也是影响4核8G服务器并发数的一大因素&#xff0c;假设公网带宽太小&#xff0c;流量直接卡在入口&#xff0c;4核8G配置的CPU内存也会造成计算…

autojs通过正则表达式获取带有数字的text内容

视频连接 视频连接 参考 参考 var ctextMatches(/\d/).findOne()console.log("当前金币"c.text()) // 获取当前金币UiSelector.textMatches(reg) reg {string} | {Regex} 要满足的正则表达式。 为当前选择器附加控件"text需要满足正则表达式reg"的条件。 …

【Linux系统 04】OpenEuler配置

目录 一、镜像文件下载 二、配置静态IP 三、启动SSH连接 四、免密登录 五、安装常用软件 一、镜像文件下载 官方下载地址&#xff1a;openEuler下载 | 欧拉系统ISO镜像 | openEuler社区官网 选择一个版本&#xff0c;lopenEuler通常有两种版本&#xff1a; 创新版&…

使用Vue.js输出一个hello world

导入vue.js <script src"https://cdn.jsdelivr.net/npm/vue2/dist/vue.js"></script> 创建一个标签 <div id"app">{{message}}</div> 接管标签内容&#xff0c;创建vue实例 <script type"text/javascript">va…

H12-821_31

31.下面是一台路由器的部分配置,关于该配置描述正确的是: A.源地址为1.1.1.1的数据包匹配第一条ACL语句rule 0,匹配规则为允许 B.源地址为1.1.1.3的数据包匹配第三条ACL语句rule 2,匹配规则为拒绝 C.源地址为1.1.1.4的数据包匹配第四条ACL语句rule 3,匹配规则为允许 D.源地址为…