【深度学习】5-3 与学习相关的技巧 - Batch Normalization

如果为了使各层拥有适当的广度,“强制性”地调整激活值的分布会怎样呢?实际上,Batch Normalization 方法就是基于这个想法而产生的

为什么Batch Norm这么惹人注目呢?因为Batch Norm有以下优点:

  • 可以使学习快速进行(可以增大学习率)。
  • 不那么依赖初始值(对于初始值不用那么神经质) 。
  • 抑制过拟合(降低Dropout等的必要性)。

Batch Norm的思路是调整各层的激活值分布使其拥有适当的广度。为此,要向神经网络中插入对数据分布进行正规化的层,即Batch Normalization层(下文简称Batch Norm层)
在这里插入图片描述
Batch Norm,顾名思义,以进行学习时的mini-batch为单位,按mini-batch进行正规化。具体而言,就是进行使数据分布的均值为0、方差为1的正规化。用数学式表示的话,如下:
在这里插入图片描述

这里对mini-batch的m个输人数据的集合B求均值方差。然后,对输人数据进行均值为0、方差为1(合适的分布)的正规化。
这个式子所做的是将mini-batch的输人数据变换为均值为0,方差为1的数据。通过将这个处理插入到激活函数的前面(或者后面),可以减少数据分布的偏向
接着,Batch Norm层会对正规化后的数据进行缩放和平移的变换用数学式可以如下表示。
在这里插入图片描述

这里,γ和β是参数。一开始γ=1,β=0,然后再通过学习调整到合适的值。
上面就是Batch Norm的算法。这个算法是神经网络上的正向传播。

用计算图表示如下:
在这里插入图片描述

Batch Norm的反向传播
Batch Norm实现类

class BatchNormalization:
    """
    http://arxiv.org/abs/1502.03167
    """
    def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        self.input_shape = None # Conv层的情况下为4维,全连接层的情况下为2维  

        # 测试时使用的平均值和方差
        self.running_mean = running_mean
        self.running_var = running_var  
        
        # backward时使用的中间数据
        self.batch_size = None
        self.xc = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        self.input_shape = x.shape
        if x.ndim != 2:
            N, C, H, W = x.shape
            x = x.reshape(N, -1)

        out = self.__forward(x, train_flg)
        
        return out.reshape(*self.input_shape)
            
    def __forward(self, x, train_flg):
        if self.running_mean is None:
            N, D = x.shape
            self.running_mean = np.zeros(D)
            self.running_var = np.zeros(D)
                        
        if train_flg:
            mu = x.mean(axis=0)
            xc = x - mu
            var = np.mean(xc**2, axis=0)
            std = np.sqrt(var + 10e-7)
            xn = xc / std
            
            self.batch_size = x.shape[0]
            self.xc = xc
            self.xn = xn
            self.std = std
            self.running_mean = self.momentum * self.running_mean + (1-self.momentum) * mu
            self.running_var = self.momentum * self.running_var + (1-self.momentum) * var            
        else:
        	# 算法实现
            xc = x - self.running_mean
            xn = xc / ((np.sqrt(self.running_var + 10e-7)))
            
        out = self.gamma * xn + self.beta 
        return out

    def backward(self, dout):
        if dout.ndim != 2:
            N, C, H, W = dout.shape
            dout = dout.reshape(N, -1)

        dx = self.__backward(dout)

        dx = dx.reshape(*self.input_shape)
        return dx

	# 反向传播
    def __backward(self, dout):
        dbeta = dout.sum(axis=0)
        dgamma = np.sum(self.xn * dout, axis=0)
        dxn = self.gamma * dout
        dxc = dxn / self.std
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
        dvar = 0.5 * dstd / self.std
        dxc += (2.0 / self.batch_size) * self.xc * dvar
        dmu = np.sum(dxc, axis=0)
        dx = dxc - dmu / self.batch_size
        
        self.dgamma = dgamma
        self.dbeta = dbeta
        
        return dx

Batch Normalization的评估

现在我们使用Batch Norm层进行实验。首先,使用MNIST数据集,观察使用Batch Norm层和不使用Batch Norm层时学习的过程会如何变化,
代码如下:

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.optimizer import SGD, Adam

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 减少学习数据
x_train = x_train[:1000]
t_train = t_train[:1000]

max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01


def __train(weight_init_std):
    bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10, 
                                    weight_init_std=weight_init_std, use_batchnorm=True)
    network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
                                weight_init_std=weight_init_std)
    optimizer = SGD(lr=learning_rate)
    
    train_acc_list = []
    bn_train_acc_list = []
    
    iter_per_epoch = max(train_size / batch_size, 1)
    epoch_cnt = 0
    
    for i in range(1000000000):
        batch_mask = np.random.choice(train_size, batch_size)
        x_batch = x_train[batch_mask]
        t_batch = t_train[batch_mask]
    
        for _network in (bn_network, network):
            grads = _network.gradient(x_batch, t_batch)
            optimizer.update(_network.params, grads)
    
        if i % iter_per_epoch == 0:
            train_acc = network.accuracy(x_train, t_train)
            bn_train_acc = bn_network.accuracy(x_train, t_train)
            train_acc_list.append(train_acc)
            bn_train_acc_list.append(bn_train_acc)
    
            print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))
    
            epoch_cnt += 1
            if epoch_cnt >= max_epochs:
                break
                
    return train_acc_list, bn_train_acc_list


# 3.绘制图形==========
weight_scale_list = np.logspace(0, -4, num=16)
x = np.arange(max_epochs)

for i, w in enumerate(weight_scale_list):
    print( "============== " + str(i+1) + "/16" + " ==============")
    train_acc_list, bn_train_acc_list = __train(w)
    
    plt.subplot(4,4,i+1)
    plt.title("W:" + str(w))
    if i == 15:
        plt.plot(x, bn_train_acc_list, label='Batch Normalization', markevery=2)
        plt.plot(x, train_acc_list, linestyle = "--", label='Normal(without BatchNorm)', markevery=2)
    else:
        plt.plot(x, bn_train_acc_list, markevery=2)
        plt.plot(x, train_acc_list, linestyle="--", markevery=2)

    plt.ylim(0, 1.0)
    if i % 4:
        plt.yticks([])
    else:
        plt.ylabel("accuracy")
    if i < 12:
        plt.xticks([])
    else:
        plt.xlabel("epochs")
    plt.legend(loc='lower right')
    
plt.show()

运行结果如下:
在这里插入图片描述

从运行结果可以看到使用Batch Norm后,学习进行得更快了。
综上,通过使用Batch Norm,可以推动学习的进行。并且,对权重初始值变得健壮(表示不那么依初始值) Batch Norm具备如此优良的性质,一定能应用在更多场合中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/30936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Web安全——HTML基础

HTML 一、对于前端以及后端的认识以及分析二、HTML认知1、网页的组成2、浏览器3、Web标准 三、简单的HTML页面架构四、HTML常见标签1、meta标签2、标题标签3、文本属性4、form表单5、a 标签6、锚文本7、img 标签8、table 表格9、列表标签9.1、无序列表9.2、有序列表 10、框架的…

Java性能权威指南-总结14

Java性能权威指南-总结14 堆内存最佳实践对象生命周期管理对象重用 堆内存最佳实践 对象生命周期管理 在很大程度上&#xff0c;Java会尽量减轻开发者投入到对象生命周期管理上的精力&#xff1a;开发者在需要的时候创建对象&#xff0c;当不再需要这些对象时&#xff0c;它们…

Java 被挤出前三。。

TIOBE 2023 年 06 月份的编程语言排行榜已经公布&#xff0c;官方的标题是&#xff1a;Python 还会保持第一吗&#xff1f;&#xff08;Will Python remain number 1?&#xff09; 在过去的 5 年里&#xff0c;Python 已经 3 次获得 TIOBE 指数年度大奖&#xff0c;这得益于…

浅谈C++|引用篇

目录 引入 一.引用的基本使用 (1)引用的概念&#xff1a; (2)引用的表示方法 (3)引用注意事项 (4)引用权限 二.引用的本质 三.引用与函数 (1)引用做函数参数 (2)引用做函数返回值 四.常量引用 五.引用与指针 引入 绰号&#xff0c;又称外号&#xff0c;是人的本名以外…

【k8s系列】一分钟搭建MicroK8s Dashboard

本文基于上一篇文章的内容进行Dashboard搭建&#xff0c;如果没有看过上一篇的同学请先查阅上一篇文章 k8s系列】使用MicroK8s 5分钟搭建k8s集群含踩坑经验 使用MicroK8s搭建Dashboard很简单&#xff0c;只需要在Master节点按照以下几步操作 1.启用Dashboard插件 microk8s en…

【软件工程】软件工程期末考试复习题

软件工程期末考试试题及参考答案 一、单向选择题 1、软件的发展经历了&#xff08;D&#xff09;个发展阶段。 一二三四 2、需求分析的任务不包括&#xff08;B&#xff09;。 问题分析系统设计需求描述需求评审。 3、一个软件的宽度是指其控制的&#xff08;C&#xff0…

[进阶]TCP通信综合案例:群聊

代码演示如下&#xff1a; 客户端&#xff1a; public class Client {public static void main(String[] args) throws Exception{System.out.println("客户端开启&#xff01;");//1.创建Socket对象&#xff0c;并同时请求与服务端程序的连接。Socket socket new…

新人拿到一个web项目如何使用idea发布运行

本文描述的是一个新手&#xff0c;拿到一个web项目&#xff0c;使用idea如何发布运行。项目中没有非常复杂的元素&#xff0c;只是试着描述应该如何配置相关内容。 内容描述前提&#xff0c;首先请您确认tomcat已经安装&#xff0c;其次确认jdk已经安装&#xff0c;并明确他们在…

STM32速成笔记—GPIO

文章目录 一、什么是GPIO二、GPIO的输入/输出模式三、GPIO初始化配置四、Boot引脚五、一些特殊的GPIO六、点亮LED1. 硬件电路2. 拉高/拉低GPIO3. 程序设计 七、GPIO的位带操作 一、什么是GPIO GPIO(英语:General-purpose input/output)&#xff0c;通用型之输入输出的简称&…

Java与SpringBoot对redis的使用方式

目录 1.Java连接redis 1.1 使用Jedis1.2 使用连接池连接redis1.3 java连接redis集群模式 2.SpringBoot整合redis 2.1 StringRedisTemplate2.2 RedisTemplate 1.Java连接redis redis支持哪些语言可以操作 &#xff08;去redis官网查询&#xff09; 1.1 使用Jedis (1)添加jedis…

【数字图像处理】2.几何变换

目录 什么是几何变换&#xff1f; 为什么要对图像进行几何变换&#xff1f; 2.1 仿射变换&#xff08;二维&#xff09; 2.2 投影变换&#xff08;三维&#xff09; 2.3 极坐标变换 2.3.1 将笛卡尔坐标转化为极坐标 2.3.2 将极坐标转换为笛卡尔坐标 2.3.3 利用极坐标变…

汇编学习教程:寻址大总结

前言 在上篇博文中&#xff0c;我们主要学习了一个全新的寄存器&#xff1a;bp。bp 寄存器在功能和使用上与 bx 有着异曲同工之妙&#xff0c;只不过两人绑定的服务对象不同&#xff1a;bx 默认绑定的是 DS 段寄存器&#xff0c;而 bp 默认绑定的是 SS 段寄存器。bx 和 bp 有着…

Unity之透明度混合与ps的透明度混合计算结果不一致

一、问题 前段时间学习shader时发现了一个问题&#xff0c;一张纯红色透明度为128的图片叠加在一张纯绿色的图片上得出的结果与ps中的结果不一致。网上查找了ps中的透明混合的公式为 color A.rgb*A.alpha B.rgb*(1-A.alpha)。自己使用代码在unity中计算了一下结果总是不对。…

【Java基础学习打卡09】JRE与JDK

目录 前言一、JRE二、JDK三、JDK、JRE和JVM关系总结 前言 本文将介绍JRE、JDK是什么&#xff0c;以及JDK、JRE和JVM关系三者之间的关系。 一、JRE JRE全称为Java Runtime Environment&#xff0c;是Java应用程序的运行时环境。JRE包括Java虚拟机&#xff08;JVM&#xff09;、…

车辆救援道路救援预约汽修托运小程序

道路救援&#xff1a;指汽车道路紧急救援&#xff0c;为故障车主提供包括诸如&#xff1a;拖吊、换水、充电、换胎、送油以及现场小修等服务(Road-Side Service)&#xff1b; 同时也指交通事故道路救援&#xff0c;包括伤员救治、道路疏导等。 随着我国巨大的汽车拥有量&…

基础篇:新手使用vs code新建go项目(从0开始到运行)

学习新语言&#xff0c;搭建新环境。在网上找了一些教程&#xff0c;感觉还是写一个比较详细的方便以后自己使用。其实vs code没有新建项目这个功能&#xff0c;具体怎么运行go语言的项目请看下文。 一、下载GO安装包 1.点击go安装包下载链接下载相应的版本&#xff08;本次下…

了解 Dockerfile 和搭建 Docker 私有仓库:让容器化部署变得更简单

目录 1、Dockerfile 1.1什么是Dockerfile 1.2常用命令 1.3使用脚本创建镜像 2、Docker私有仓库 2.1私有仓库介绍&#xff1a; 2.2私有仓库搭建与配置 2.3上传镜像到私有仓库&#xff1a; 1、Dockerfile 1.1什么是Dockerfile Dockerfile是由一些列命令和参数构成的脚本…

《网络安全0-100》安全事件案例

网络安全事件案例分析 2017年Equifax数据泄露事件 Equifax是美国一家信用评级机构&#xff0c;2017年9月&#xff0c;该公司披露发生了一起重大的数据泄露事件&#xff0c;涉及1.43亿美国人的个人信息&#xff0c;包括姓名、出生日期、社会安全号码等敏感信息。经过调查&#…

【数据分析】如何使用docker部署程序并移植(算法、接口)

原文作者&#xff1a;我辈李想 版权声明&#xff1a;文章原创&#xff0c;转载时请务必加上原文超链接、作者信息和本声明。 文章目录 前言一、Docker的基本使用1.安装Docker2.列出本地镜像3.获取镜像,创建本地ubuntu:13.10镜像4.查找镜像5.删除本地镜像6.创建自定义镜像7.镜像…

第一章 基础算法(二)——高精度,前缀和与差分

文章目录 高精度运算高精度加法高精度减法高精度乘法高精度除法 前缀和二维前缀和 差分二维差分 高精度练习题791. 高精度加法792. 高精度减法793. 高精度乘法794. 高精度除法 前缀和练习题795. 前缀和796. 子矩阵的和 差分练习题797. 差分798. 差分矩阵 高精度运算 两个大数做…