(动手学习深度学习)第7章 批量规范化(Batch Normalization)

BN

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 批量归一化固定小批量中的均值和方差,然后学习出适合的偏移和缩放。
  • 可以加速收敛速度,但一般不改变模型精度。

BN代码手动实现

  1. 导入相关库
import torch
from torch import nn
from d2l import torch as d2l
  1. 定义BN层
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """
    :param X: 输入数据
    :param gamma: γ
    :param beta: β
    :param moving_mean: 全局均值
    :param moving_var: 全局方差
    :param eps: ε
    :param momentum: 冲量:用来更新或固定常量
    :return:输出数据, 全局均值, 全局方差
    """
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况:计算特征维度上的均值和方差
            mean = X.mean(dim=0)
            var = ((X-mean)**2).mean(dim=0)
        else:
            # 使用二维卷积层情况:计算通道维上(axis=1)的均值和方差
            # 这里需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim = (0, 2, 3), keepdim=True)
            var = ((X - mean)**2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var

    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):

    def __init__(self, num_features, num_dims):
        """
        :param num_features: 全连接层的输出数量或卷积层的输出通道数
        :param num_dims: 2表示全连接层,4表示卷积层
        """
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸参数和偏移参数,其分别初始化为1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上, 将moving_mean和moving_var复制到X所在的显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,eps=1e-5, momentum=0.9
        )
        return Y
  1. 应用BN与LeNet模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), BatchNorm(6, num_dims=4), nn.Sigmoid(), # [1, 6, 28, 28]
    nn.AvgPool2d(2, stride=2),  # [1, 6, 14, 14]

    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),  # [1, 16, 10, 10]
    nn.AvgPool2d(2, stride=2),  # [1, 16, 7, 7]

    nn.Flatten(),  # [1, 16*5*5]

    nn.Linear(16*5*5, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),  # [1, 400] -->[1, 120]
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),  # [1, 120] --> [1, 84]
    nn.Linear(84, 10)  # [1, 82] --> [1, 10]
)
  1. 查看模型
X = torch.randn((1, 1, 28, 28))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape: \t', X.shape)
  1. 加载数据集
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
  1. 训练模型
import time
lr, num_epochs = 1.0, 10
start = time.perf_counter()
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
end = time.perf_counter()
print("运行耗时 %.4f s" % (end-start))

在这里插入图片描述

在这里插入图片描述
7. 查看拉伸参数gamma和偏移参数beta

net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1, ))

在这里插入图片描述

BN代码简洁实现

  • 修改模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2),
    nn.BatchNorm2d(6), nn.Sigmoid(),  # [1, 6, 28, 28]
    nn.AvgPool2d(2, stride=2),  # [1, 6, 14, 14]

    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16), nn.Sigmoid(),  # [1, 16, 10, 10]
    nn.AvgPool2d(2, stride=2),  # [1, 16, 5, 5]

    nn.Flatten(),  # [1, 16*5*5]

    nn.Linear(16 * 5 * 5, 120),
    nn.BatchNorm1d(120), nn.Sigmoid(),  # [1, 120]
    nn.Linear(120, 84),
    nn.BatchNorm1d(84), nn.Sigmoid(),  # [1, 84]
    nn.Linear(84, 10))  # [1, 10]
  • 训练模型
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/127407.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【FGPA】Verilog:移位寄存器 | 环形计数器 | 4bit移位寄存器的实现 | 4bit环形计数器的实现

目录 Ⅰ. 理论部分 0x00 移位寄存器(Shift Register) 0x01 环形计数器(Ring Counter) Ⅱ. 实践部分 0x00 移位寄存器(4-bit) 0x01 四位环形寄存器(4-bit) Ⅰ. 理论部分 0x00 …

如何在3dMax中使用MaxScript在视口中显示数据?

如何在3dMax中使用MaxScript在视口中显示数据? 详细的教程指南,介绍如何使用MaxScript在视口中直接显示对象名称、坐标和顶点索引等信息。 在本教程中,您将学习如何借助MaxScript在视口中直接显示对象信息或数据。 本教程介绍了如何显示简单的…

LeetCode算法题解(回溯、难点)|LeetCode51. N 皇后

LeetCode51. N 皇后 题目链接:51. N 皇后 题目描述: 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。…

[PyTorch][chapter 61][强化学习-免模型学习 off-policy]

前言: 蒙特卡罗的学习基本流程: Policy Evaluation : 生成动作-状态轨迹,完成价值函数的估计。 Policy Improvement: 通过价值函数估计来优化policy。 同策略(one-policy):产生 采样轨迹的策略 和要改…

【K8s集群离线安装-kubeadm】

1、kubeadm概述 kubeadm是官方社区推出的一个用于快速部署kubernetes集群的工具。这个工具能通过两条指令快速完成一个kubernetes集群的部署。 2、环境准备 2.1 软件环境 软件版本操作系统CentOS 7Docker19.03.13K8s1.23 2.2 服务器 最小硬件配置:2核CPU、2G内存…

19.5 Boost Asio 传输结构体

同步模式下的结构体传输与原生套接字实现方式完全一致,读者需要注意的是在接收参数是应该使用socket.read_some函数读取,发送参数则使用socket.write_some函数实现,对于套接字的解析同样使用强制指针转换的方法。 服务端代码如下所示 #incl…

「Verilog学习笔记」4位数值比较器电路

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点,刷题网站用的是牛客网 分析 这里要注意题目的“门级描述方式”,所以我们只能使用基本门电路:&,|,!,^,^~。 具体实现思路:通过真值表得出Y0 Y1 Y2的逻辑表达…

Vue3使用vue-print-nb插件打印功能

插件官网地址https://www.npmjs.com/package/vue-print-nb 效果展示: 打印效果 根据不同的Vue版本安装插件 //Vue2.0版本安装方法 npm install vue-print-nb --save pnpm install vue-print-nb --save yarn add vue-print-nb//Vue3.0版本安装方法: npm install vue3…

优思学院|CTP和CTQ是什么?有什么区别?

CTQ 关键质量特性 CTQ是在六西格玛管理中常用的重要词汇,所以很多不同界别的人仕都可能听过,CTQ的意思是关键质量特性,Critical To Quality 的缩写。 六西格玛管理提倡的方法是通过客户的声音 (Voice of customer-VOC) ,然后把它…

绝对力作:解锁string的所有关键接口,万字深度解析!

W...Y的主页 😊 🍔前言: 通过博主的上篇文章,我相信大家已经认识了STL并且已经迫不及待想学习了,现在我们就走近STL的第一种类——string。 目录 为什么学习string类? C语言中的字符串 标准库中的str…

使用 Socks5 来劫持 HTTPS(TCP-TLS) 之旅

MITM 劫持的过程中,HTTP 协议并不是唯一选择。 实际在 MITM 使用过程中,BurpSuite 和 Yakit 提供的交互式劫持工具只能劫持 HTTP 代理的 TLS 流量;但是这样是不够的,有时候我们并不能确保 HTTP 代理一定生效,或者说特…

【js逆向实战】某sakura动漫视频逆向

写在前面 再写一个逆向实战,后面写点爬虫程序来实现一下。 网站简介与逆向目标 经典的一个视频网站,大多数视频网站走的是M3U8协议,就是一个分段传输,其实这里就有两个分支。 通过传统的m3u8协议,我们可以直接进行分…

python回文日期 并输出下一个ABABBABA型回文日期

题目: 输入: 输入包含一个八位整数N,表示日期 对于所有的测评用例,10000101 ≤N≤89991231,保证N是一个合法日期的8位数表示 输出: 输出两行,每行一个八位数。第一行表示下一个回文日期第二…

【论文阅读】DALL·E: Zero-Shot Text-to-Image Generation

OpenAI第一代文本生成图片模型 paper:https://arxiv.org/abs/2102.12092 DALLE有120亿参数,基于自回归transformer,在2.5亿 图片-文本对上训练的。实现了高质量可控的text to image,同时也有zero-shot的能力。 DALL-E没有使用扩…

【腾讯云 HAI域探秘】探索AI绘画之路:利用腾讯云HAI服务打造智能画家

目录 前言1 使用HAI服务作画的步骤1.1 注册腾讯云账户1.2 创建算力服务器1.3 进入模型管理界面1.4 汉化界面1.5 探索AI绘画 2 模型参数的含义和调整建议2.1 模型参数的含义和示例2.2 模型参数的调整建议 3 调整参数作画的实践和效果3.1 实践说明3.2 实践效果13.3 实践效果23.4 …

专门为Web应用程序提供安全保护的设备-WAF

互联网网站面临着多种威胁,包括网络钓鱼和人为的恶意攻击等。这些威胁可能会导致数据泄露、系统崩溃等严重后果。 因此,我们需要采取更多有效的措施来保护网站的安全。其中WAF(Web application firewall,Web应用防火墙&#xff0…

网站接口测试记录

1.被测试服务器端口输入htop指令进行cpu监控 2.测试机器安装宝塔-》我的工具-》进行网站测试 访问地址:https://www.bt.cn/bbs/thread-52772-1-1.html

Spring Cloud智慧工地管理平台源码,智慧工地APP源码,实现对劳务人员、施工进度、工地安全、材料设备、环境监测等方面的实时监控和管理

智慧工地管理平台源码,智慧工地APP源码, 智慧工地管理平台实现对人员管理、施工进度、安全管理、材料管理、设备管理、环境监测等方面的实时监控和管理,提高施工效率和质量,降低安全风险和环境污染。智慧工地平台支持项目级、公司…

SpringCloud——负载均衡——Ribbon

负载均衡分为集中式LB(Nginx实现)和进程内LB(Ribbon)。 Ribbon简单来说就是负载均衡RestTemplate调用。 1.Ribbon在工作中分成两步 1.先选择EurekaServer,它优先选择在同一个区域内负载较少的EurekaServer。 2.在根据用户指定的策略,从服务注册的列表…

Go 什么是循环依赖

Go 中的循环依赖是指两个或多个包之间相互引用,形成了一个循环依赖关系。这种情况下,包 A 依赖包 B,同时包 B 也依赖包 A,导致两个包之间无法明确地确定编译顺序,从而可能引发编译错误或其他问题。循环依赖是 Go 中需要…