暂退法(丢弃法)

       在深度学习中,丢弃法(Dropout)是一种常用的正则化技术,旨在减少模型的过拟合现象,可能会比之前的权重衰减(Weight Decay)效果更好。通过在训练过程中随机丢弃一部分神经元,可以有效地减少神经网络中的参数依赖性,增强模型的泛化能力。

一、丢弃法原理介绍

1、动机

       一个好的模型需要对输入数据的扰动鲁棒,也就是说,不管图片加入多少噪音,我也是能看清楚的。使用有噪音的数据等价于Tikhonov正则,正则使得权重值范围不会太大,避免一定的过拟合。与之前加入的噪音不一样,之前是固定噪音,丢弃法是随机噪音,丢弃法不是在输入加噪音,而是在层之间加入噪音,所以丢弃法也算是一个正则。

2、无偏差的加入噪音

       假如$x$是上一层到下一层的某一个输出(上一层输出向量的某一个元素)的话,对$x$加入噪音得到$x'$,我们希望加入噪音后不改变期望,即:

$ E\left[ x' \right] =x $

       丢弃法对上一层输出向量的每一个元素做如下扰动:

       此时这个元素的期望是不变的:

$ E\left[ x' \right] =p\cdot 0+\left( 1-p \right) \frac{x}{1-p}=x $

3、丢弃法的使用

       通常将丢弃法作用在隐藏全连接层的输出上。如图"MLP with one hidden layer"带有1个隐藏层和5个隐藏单元的多层感知机。当我们将暂退法应用到隐藏层,以$p$的概率将隐藏单元置为零时,结果可以看作一个只包含原始神经元子集的网络。比如在图"Hidden layer after dropout"中,删除了$h_2$$h_5$,因此输出的计算不再依赖于$h_2$$h_5$,并且它们各自的梯度在执行反向传播时也会消失。这样,输出层的计算不能过度依赖于$h_1, \ldots, h_5$的任何一个元素。

4、推理中的丢弃法

5、总结

  • 丢弃法将一些输出项随机置0来控制模型复杂度
  • 常作用在多层感知机的隐藏层输出上
  • 丢弃概率是控制模型复杂度的超参数

二、暂退法从零开始实现

1、定义dropout函数

       要实现单层的暂退法函数,我们从均匀分布$U[0, 1]$中抽取样本,样本数与这层神经网络的维度一致。然后我们保留那些对应样本大于$p$的节点,把剩下的丢弃。

       在下面的代码中,我们实现 `dropout_layer` 函数,该函数以`dropout`的概率丢弃张量输入`X`中的元素,如上所述重新缩放剩余部分:将剩余部分除以`1.0-dropout`。

import torch
from torch import nn
from d2l import torch as d2l

def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    # torch.rand(X.shape)生成了一个与输入张量X相同形状的随机数张量,其中的元素值在[0, 1)的区间内均匀分布。
    # (torch.rand(X.shape) > dropout)执行了一个逻辑判断,将随机数张量中大于dropout的元素置为True,小于等于dropout的元素置为False。
    # .float()将布尔型张量转换为浮点型张量,将True转换为1.0,将False转换为0.0。
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

       我们可以通过下面几个例子来测试`dropout_layer`函数。我们将输入`X`通过暂退法操作,暂退概率分别为0、0.5和1。

X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  2.,  0.,  6.,  0.,  0.,  0., 14.],
        [16., 18.,  0., 22.,  0., 26., 28., 30.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

2、定义模型参数

       同样,我们使用Softmax回归中引入的Fashion-MNIST数据集(不懂的可以看链接里面的文章)。我们定义具有两个隐藏层的多层感知机,每个隐藏层包含256个单元。

num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

3、定义模型

       我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后),并且可以为每一层分别设置暂退概率:常见的技巧是在靠近输入层的地方设置较低的暂退概率。下面的模型将第一个和第二个隐藏层的暂退概率分别设置为0.2和0.5,并且暂退法只在训练期间有效。

dropout1, dropout2 = 0.2, 0.5    # 在靠近输入层的地方设置较低的暂退概率,因此dropout1设为0.2

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out


net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

4、训练和测试

       这类似于前面描述的多层感知机训练和测试。

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

三、暂退法简洁实现

1、定义模型

       对于深度学习框架的高级API,我们只需在每个全连接层之后添加一个`Dropout`层,将暂退概率作为唯一的参数传递给它的构造函数。在训练时,`Dropout`层将根据指定的暂退概率随机丢弃上一层的输出(相当于下一层的输入)。在测试时,`Dropout`层仅传递数据。

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

2、训练和测试

       接下来,我们对模型进行训练和测试。

trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

四、总结

  • 暂退法在前向传播过程中,计算每一内部层的同时丢弃一些神经元。
  • 暂退法可以避免过拟合,它通常与控制权重向量的维数和大小结合使用的。
  • 暂退法将活性值$h$替换为具有期望值$h$的随机变量。
  • 暂退法仅在训练期间使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/252288.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【华为数据之道学习笔记】5-4 数据入湖方式

数据入湖遵循华为信息架构&#xff0c;以逻辑数据实体为粒度入湖&#xff0c;逻辑数据实体在首次入湖时应该考虑信息的完整性。原则上&#xff0c;一个逻辑数据实体的所有属性应该一次性进湖&#xff0c;避免一个逻辑实体多次入湖&#xff0c;增加入湖工作量。 数据入湖的方式…

如何从 iPhone 上恢复已删除的照片教程分享

您是否错误地删除了 iPhone 上的错误照片&#xff1f;或者您可能已将手机恢复出厂设置&#xff0c;但现在所有照片都消失了&#xff1f;如果您现在遇到这样的情况&#xff0c;我们可以为您提供解决方案。 在本文中&#xff0c;我们将向您展示七种数据恢复方法&#xff0c;可以…

对可恢复的情况使用受检异常

在Java中&#xff0c;受检异常&#xff08;Checked Exception&#xff09;通常用于表示程序能够预期并且可能进行恢复的异常情况。这类异常是在编译时由编译器强制进行处理的&#xff0c;使得程序员必须显式处理这些异常&#xff0c;或者在方法签名中使用 throws 关键字声明。 …

Zoho Mail企业邮箱:6大高效使用技巧

本期小Z就带您了解下&#xff0c;Zoho Mail非常实用的几个小技巧&#xff0c;帮助您进一步提升工作效率。 01/轻松从其他邮箱迁移到Zoho Mail 每天我们都会收到很多垃圾邮件、网络钓鱼、或者不需要的促销邮件等等&#xff0c;筛选出这些邮件耗时耗力&#xff0c;这个时候寻找…

借着期末作业,写一个JavaWeb项目

合集传送门 要求 学生成绩管理系统设计与实现 设计一个学生成绩管理系统。根据以下功能&#xff0c;分析使用的逻辑结构和存储结构。并设计菜单&#xff0c;显示相应结果。 &#xff08;1&#xff09;录入功能&#xff1a;能够录入学生成绩&#xff08;包括&#xff1a;学号…

面向对象方法分析之 各种图

没想到研究生了还要跟这些奇奇怪怪的图打交道&#xff08;不是&#xff09; 以下是面向对象分析中常用到的图。 静态视图 静态视图对应用领域中的概念以及与系统实现有关的内部概念建模&#xff0c;主要由类以及类之间的相互关系组成&#xff0c;在静态视图中不描述依赖于时…

《opencv实用探索·二十》点追踪技术

前言&#xff1a; 在学习点追踪技术前需要先了解下光流发追踪目标&#xff0c;可以看上一章内容&#xff1a;光流法检测运动目标 如果以光流的方式追踪目标&#xff0c;基本上我们可以通过goodFeaturesToTrack函数计算一系列特征点&#xff0c;然后通过Lucas-Kanade算法进行一…

vue3 setup语法糖写法基本教程

前言 官网地址&#xff1a;Vue.js - 渐进式 JavaScript 框架 | Vue.js (vuejs.org)下面只讲Vue3与Vue2有差异的地方&#xff0c;一些相同的地方我会忽略或者一笔带过与Vue3一同出来的还有Vite&#xff0c;但是现在不使用它&#xff0c;等以后会有单独的教程使用。目前仍旧使用v…

XXE漏洞 [NCTF2019]Fake XML cookbook1

打开题目 查看源代码 发现我们post传入的数据都被放到了doLogin.php下面 访问一下看看 提示加载外部xml实体 bp抓包一下看看 得到flag 或者这样 但是很明显这样是不行的&#xff0c;因为资源是在admin上&#xff0c;也就是用户名那里 PHP引用外部实体&#xff0c;常见的利用…

Explain工具-SQL性能优化

文章目录 SQL性能优化的目标Explain中type效率级别&#xff08;重要&#xff09;注意 Explain覆盖索引ExplainindexExplainfilesortExplainfilesort创建 idx_bd(b,d) SQL性能优化的目标 达到 range 级别 Explain中type效率级别&#xff08;重要&#xff09; 显示的是单位查询…

Ubuntu系统入门指南:基础操作和使用

Ubuntu系统的基础操作和使用 一、引言二、安装Ubuntu系统三、Ubuntu系统的基础操作3.1、界面介绍3.2、应用程序的安装和卸载3.3、文件管理3.4、系统设置 四、Ubuntu系统的日常使用4.1、使用软件中心4.2、浏览器的使用和网络连接设置4.3、邮件客户端的配置和使用4.4、文件备份和…

【上海大学数字逻辑实验报告】七、中规模元件及综合设计

一、实验目的 掌握中规模时序元件的测试。学会在Quartus II上设计序列发生器。 二、实验原理 74LS161是四位可预置数二进制加计数器&#xff0c;采用16引脚双列直插式封装的中规模集成电路&#xff0c;其外形如下图所示&#xff1a; 其各引脚功能为&#xff1a; 异步复位输…

我的世界指令大全

不知道我的世界指令大全&#xff1f;我的世界玩家我的世界指令大全&#xff0c;在下面的攻略中就为小伙伴们带来了我的世界指令大全。不知道我的世界指令大全的玩家们那就继续看下去这篇攻略吧。 1./summon&#xff1a;召唤实体。 2./give&#xff1a;给予玩家物品。 3./tp&a…

Epic 安装失败,错误代码SUPQR1612,必要的先决条件安装失败,弹窗CD-ROM

错误记录 并且弹出来这个窗口the feature you are trying to use is on a CD-ROM or other removable disk that is not available. 让我寻找这个msi文件 正在想办法解决 如果有人能解决 麻烦告知评论区&#xff01;

Nginx编译安装+Nginx模块详解+Nginx虚拟主机(新版)

Nginx编译安装Nginx模块详解Nginx虚拟主机 Nginx编译安装Nginx模块详解Nginx虚拟主机一、编译安装Nginx服务二、nginx版本升级1、nginx平滑升级的步骤2、示例 三、添加Nginx系统服务1、使用init.d脚本2、使用 systemd 服务配置 四、认识Nginx服务的主配置文件 nginx.conf1、全局…

为什么Apache Doris适合做大数据的复杂计算,MySQL不适合?

为什么Apache Doris适合做大数据的复杂计算&#xff0c;MySQL不适合&#xff1f; 一、背景说明二、DB架构差异三、数据结构差异四、存储结构差异五、总结 一、背景说明 经常有小伙伴发出这类直击灵魂的疑问&#xff1a; Q&#xff1a;“为什么Apache Doris适合做大数据的复杂计…

819. 最常见的单词

819. 最常见的单词 Java&#xff1a;split() 过滤 class Solution {public String mostCommonWord(String paragraph, String[] banned) {String s paragraph.replaceAll("\\p{Punct}", " "); // 去除所有标点符号String arr[] s.split(" "…

OpenCV开发:MacOS源码编译opencv,生成支持java、python、c++各版本依赖库

OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个开源的计算机视觉和机器学习软件库。它为开发者提供了丰富的工具和函数&#xff0c;用于处理图像和视频数据&#xff0c;以及执行各种计算机视觉任务。 以下是 OpenCV 的一些主要特点和功能&#xff…

Apache Flume(3):数据持久化

1 使用组件 File Channel 2 属性设置 属性名默认值说明type-filecheckpointDir~/.flume/file-channel/checkpoint检查点文件存放路径dataDirs~/.flume/file-channel/data日志存储路径&#xff0c;多个路径使用逗号分隔. 使用不同的磁盘上的多个路径能提高file channel的性能 …

74hc244驱动数码管显示电路及程序

把七或八只发光二极管组合在一个模件上组成了个8字和小数点&#xff0c;用以显示数字。为了减少管脚&#xff0c;把各个发光管的其中同一个极接在一起作为共用点&#xff0c;因此就产生了共阳极和共阴极数码之说。共阳管就是把各个发光管的正极接在一起&#xff0c;而共阴管就刚…