深度学习——批标准化Batch Normalization

什么是批标准化?

批标准化(Batch Normalization)是深度学习中常用的一种技术,旨在加速神经网络的训练过程并提高模型的收敛速度。
批标准化通过在神经网络的每一层中对输入数据进行标准化来实现。具体而言,对于每个输入样本,在每一层的前向传播过程中,都会计算其均值和方差,并使用批量内的均值和方差对输入进行标准化。标准化后的数据会经过缩放和平移操作,使得网络可以学习到适合当前任务的特定数据分布。这样做的好处包括:
1.收敛速度更快:批标准化有助于避免梯度消失和梯度爆炸问题,使得神经网络在训练过程中更快地收敛。
2.允许更高的学习率:标准化输入可以使学习率的选择更加宽松,使得学习过程更加稳定。
3.正则化作用:批标准化在一定程度上具有正则化的效果,有助于防止过拟合。
4.不那么依赖初始化:由于标准化的存在,对网络的初始权重设置并不像传统网络那样敏感,这简化了网络的初始化过程。

对比使用批标准化和不使用批标准化

import torch
from torch import nn
from torch.nn import init
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np

# 用于可复现
# torch.manual_seed(1)    # reproducible
# np.random.seed(1)

# Hyper parameters
# 样本点
N_SAMPLES = 2000
# 批大小
BATCH_SIZE = 64
# 轮次
EPOCH = 12
# 学习率
LR = 0.03
# 隐藏层层数
N_HIDDEN = 8
# 激活函数
ACTIVATION = torch.tanh
B_INIT = -0.2   # use a bad bias constant initializer

# training data
# 生成-7到10之间的N_SAMPLES个值的等差数列,并将其转化为一个二维列向量
x = np.linspace(-7, 10, N_SAMPLES)[:, np.newaxis]
# 生成一个均值为0,标准差为2的和x相同形状的噪声数据
noise = np.random.normal(0, 2, x.shape)
# 生成x对应的y值
y = np.square(x) - 5 + noise

# test data
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noise

train_x = torch.from_numpy(x).float()
train_y = torch.from_numpy(y).float()
test_x = torch.from_numpy(test_x).float()
test_y = torch.from_numpy(test_y).float()

train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)

# show data
# plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
# plt.scatter(test_x.numpy(), test_y.numpy(), c='blue', s=50, alpha=0.2, label='test')
# plt.legend(loc='best')



class Net(nn.Module):
    def __init__(self, batch_normalization=False):
        super(Net, self).__init__()
        # 是否进行批标准化
        self.do_bn = batch_normalization
        # 全连接层的列表
        self.fcs = []
        # 批标准化层的列表
        self.bns = []
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)   # for input data

        for i in range(N_HIDDEN):               # build hidden layers and BN layers
            # 如果是第一层,输入神经元个数为1,其余为10个
            input_size = 1 if i == 0 else 10
            # 全连接层
            fc = nn.Linear(input_size, 10)
            # 将全连接层重新命名然后设置为类属性
            setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module
            # 对全连接层的参数进行初始化
            self._set_init(fc)                  # parameters initialization
            # 添加到列表中
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module
                self.bns.append(bn)

        self.predict = nn.Linear(10, 1)         # output layer
        self._set_init(self.predict)            # parameters initialization

    def _set_init(self, layer):
        init.normal_(layer.weight, mean=0., std=.1)
        init.constant_(layer.bias, B_INIT)

    # 前向传播
    def forward(self, x):
        pre_activation = [x]
        if self.do_bn:
            x = self.bn_input(x)     # input batch normalization
        layer_input = [x]
        for i in range(N_HIDDEN):
            x = self.fcs[i](x)
            pre_activation.append(x)
            if self.do_bn: x = self.bns[i](x)   # batch normalization
            x = ACTIVATION(x)
            layer_input.append(x)
        out = self.predict(x)
        # 返回预测值、每个隐藏层的输入、激活函数的输出
        return out, layer_input, pre_activation


nets = [Net(batch_normalization=False), Net(batch_normalization=True)]

# print(*nets)    # print net architecture

# 优化器
opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
# MSE作为损失函数
loss_func = torch.nn.MSELoss()


def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
        [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
        if i == 0:
            p_range = (-7, 10);the_range = (-7, 10)
        else:
            p_range = (-4, 4);the_range = (-1, 1)
        ax_pa.set_title('L' + str(i))
        ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
        ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
        for a in [ax_pa, ax, ax_pa_bn, ax_bn]: a.set_yticks(());a.set_xticks(())
        ax_pa_bn.set_xticks(p_range);ax_bn.set_xticks(the_range)
        axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
    plt.pause(0.01)


if __name__ == "__main__":
    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
    # 开启动态绘制
    plt.ion()  # something about plotting
    plt.show()

    # training
    losses = [[], []]  # recode loss for two networks

    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        layer_inputs, pre_acts = [], []
        # 训练两个网络
        for net, l in zip(nets, losses):
            net.eval()              # set eval mode to fix moving_mean and moving_var
            pred, layer_input, pre_act = net(test_x)
            l.append(loss_func(pred, test_y).data.item())
            layer_inputs.append(layer_input)
            pre_acts.append(pre_act)
            net.train()             # free moving_mean and moving_var
        plot_histogram(*layer_inputs, *pre_acts)     # plot histogram

        for step, (b_x, b_y) in enumerate(train_loader):
            for net, opt in zip(nets, opts):     # train for each network
                # 获取到预测值
                pred, _, _ = net(b_x)
                # 计算loss
                loss = loss_func(pred, b_y)
                # 梯度清零
                opt.zero_grad()
                # 误差反向传播
                loss.backward()
                # 逐步优化网络参数
                opt.step()    # it will also learns the parameters in Batch Normalization

    # 关闭动态绘制
    plt.ioff()

    # plot training loss
    # 绘制loss图
    plt.figure(2)
    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
    plt.xlabel('step')
    plt.ylabel('test loss')
    plt.ylim((0, 2000))
    plt.legend(loc='best')

    # evaluation
    # set net to eval mode to freeze the parameters in batch normalization layers
    [net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
    preds = [net(test_x)[0] for net in nets]
    plt.figure(3)
    # 测试拟合效果
    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
    plt.legend(loc='best')
    plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/46056.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3+elementplus后台管理系统,实现侧边栏菜单显示到主内容区域

目录 1 创建页面2 设置路由3 修改首页4 首页的完整代码总结 我们已经使用vue3和elmentplus初步搭建了首页,上一篇中有个问题没解决,就是在侧边栏导航功能里,如果点击菜单希望是在首页打开页面而不是跳转到新页面。以下是我们希望实现的效果 这…

AI学习笔记二:YOLOV5环境搭建及测试全过程

若该文为原创文章,转载请注明原文出处。 记录yolov5从环境搭建到测试全过程。 一、运行环境 1、系统:windows10 (无cpu) 2、yolov5版本:yolov5-5.0 3、python版本:py3.8 在创建虚拟环境前需要先把miniconda3和py…

GUI自动化测试进阶:页面对象模式

本文介绍的是页面对象设计模式及其常见的滥用继承的错误。 本文和语言无关,但作者主要使用python和java。本文假设读者已经具有了一定的python或java基础,知道类和方法是什么。 如果完全没有这方面的基础,请看我的《测试人员如何学Python》。…

【图像分类】基于LIME的CNN 图像分类研究(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 💥1 概述 基于LIME(Local Interpretable Model-Agnostic Explanations)的CNN图像分类研究是一种用于解释CNN模型的方法。LIME是一…

【UE5 多人联机教程】04-加入游戏

效果 步骤 1. 新建一个控件蓝图,父类为“USC_Button_Standard” 控件蓝图命名为“UMG_Item_Room”,用于表示每一个搜索到的房间的界面 打开“UMG_Item_Room”,在图表中新建一个变量,命名为“Session” 变量类型为“蓝图会话结果…

自恢复保险丝(PPTC)的金属材料说明

保险丝大家都是知道的,但保险丝当中的自恢复保险丝(PPTC)可能就不太了解的。 其实PPTC自恢复保险丝与大家所认识的保险丝一样,都是起到限流作用,达到电路防护效果。简单来说就是一旦电路中的电流超过所规定的电流时&am…

【数据结构】二叉树详解(3)

⭐️ 前言 ✨ 往期链接:【数据结构】二叉树详解(1) 在第一篇二叉树文章中,我们探讨了二叉树的链式结构定义与实现。二叉的遍历包含( 前序/中序/后序遍历 )及代码实现和递归流程图的详细讲解。还有一些二叉树的其他接口定义与实现,包含 Binar…

【vue3】vue3的一般项目结构、成功显示自己的vue3页面

一、vue3的一般项目结构 Vue 3并没有规定特定的项目结构,因此您可以根据项目的需求和个人偏好来组织您的Vue 3项目。以下是一个常见的Vue 3项目结构示例,供参考: your-project/|- public/| |- index.html # 应用程序的入口HTML文件…

【Matlab】基于粒子群优化算法优化BP神经网络的时间序列预测(Excel可直接替换数据)

【Matlab】基于粒子群优化算法优化BP神经网络的时间序列预测(Excel可直接替换数据) 1.模型原理2.数学公式3.文件结构4.Excel数据5.分块代码5.1 fun.m5.2 main.m 6.完整代码6.1 fun.m6.2 main.m 7.运行结果 1.模型原理 基于粒子群优化算法(Pa…

ubuntu 18.04 磁盘太满无法进入系统

安装了一个压缩包,装了一半提示磁盘空间少导致安装失败。我也没在意,退出虚拟机打算扩展硬盘。等我在虚拟机设置中完成扩展操作,准备进入虚拟机内部进行操作时,发现登录不进去了 shift 登入GUN GRUB设置项的问题 网上都是在开机…

持续贡献开源力量,棱镜七彩加入openKylin

近日,棱镜七彩签署 openKylin 社区 CLA(Contributor License Agreement 贡献者许可协议),正式加入openKylin 开源社区。 棱镜七彩成立于2016年,是一家专注于开源安全、软件供应链安全的创新型科技企业。自成立以来&…

Cesium态势标绘专题-圆角矩形(标绘+编辑)

标绘专题介绍:态势标绘专题介绍_总要学点什么的博客-CSDN博客 入口文件:Cesium态势标绘专题-入口_总要学点什么的博客-CSDN博客 辅助文件:Cesium态势标绘专题-辅助文件_总要学点什么的博客-CSDN博客 本专题没有废话,只有代码,代码中涉及到的引入文件方法,从上面三个链…

剑指offer41.数据流中的中位数

我一开始的想法是既然要找中位数,那肯定要排序,而且这个数据结构肯定要能动态的添加数据的,肯定不能用数组,于是我想到了用优先队列,它自己会排序都不用我写,所以addNum方法直接调用就可以,但是…

小创业公司死亡剧本

感觉蛮真实的;很多小创业公司没有阿里华为的命,却得了阿里华为的病。小的创业公司要想活无非以下几点: 1 现金流,现金流,现金流; 2 产品,找痛点,不要搞伪需求; 3 根据公司…

让婚礼策划展示小程序成为你的必备利器

在当今互联网时代,微信小程序已经成为了很多企业和个人展示自己产品和服务的重要渠道。如果你想学习微信小程序开发,下面将为你介绍一些基本步骤。 首先,你需要注册并登录一个第三方小程序制作平台,比如乔拓云平台。这些平台提供了…

Git-分布式版本控制工具

Git仓库:本地和远程 获取git仓库: 本地初始化Git仓库(创建空目录,右键git bansh,执行git init)远程仓库克隆,git clone 远程仓库地址 版本库:.git隐藏文件夹,储存配置信…

【SCI一区】互联燃料电池混合动力汽车通过信号交叉口的生态驾驶双层凸优化(Matlab代码实现)

目录 💥1 概述 1.2 电动车动力学方程 1.3 电池模型 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码、数据、文章讲解 💥1 概述 文献来源: 随着车辆互联性的出现,互联汽车 (CVs) 在增强道路安全、改…

Android平台如何实现第三方模块编码后(H.264/H.265/AAC/PCMA/PCMU)数据实时预览播放

技术诉求 我们在做GB28181设备对接模块和RTMP直播推送模块的时候,遇到这样的技术需求,设备(如执法记录仪)侧除了采集传统的摄像头外,还需要对接比如大疆等第三方数据源,确保按照GB28181规范和RTMP协议规范…

量化交易——python数据分析及可视化

该项目分为两个部分:一是数据计算,二是可视化,三是MACD策略 一、计算MACD 1、数据部分 数据来源:tushare 数据字段包含:日期,开盘价,收盘价,最低价,最高价&#xff0c…

C# 用队列实现栈

225 用队列实现栈 请你仅使用两个队列实现一个后入先出(LIFO)的栈,并支持普通栈的全部四种操作(push、top、pop 和 empty)。 实现 MyStack 类: void push(int x) 将元素 x 压入栈顶。 int pop() 移除并返…