【深度学习】Dropout、DropPath

一、Dropout

1. 概念

Dropout 在训练阶段会让当前层每个神经元以drop_prob( 0 ≤ drop_prob ≤ 1 0\leq\text{drop\_prob}\leq1 0drop_prob1)的概率失活并停止工作,效果如下图。

测试阶段不会进行Dropout。由于不同批次、不同样本的神经元失活情况不同,测试时枚举所有情况进行推理是不现实的,所以原文使用一种均值近似的方法进行逼近。详情如下图:

如图, w \bold{w} w为一个神经元后的权重。假设该神经元的输出均值为 μ \mu μ,若训练阶段该神经元的存活概率为 p p p,则Dropout使其输出均值变为 p × μ p\times\mu p×μ,为使测试时该神经元输出逼近训练输出,测试阶段该神经元输出会被乘上 p p p以使测试与训练输出均值相同。

简单来说,训练时Dropout按照概率drop_prob使神经元停止工作,测试时所有神经元正常工作,但其输出值要乘上1-drop_prob( p = 1 − drop_prob p=1-\text{drop\_prob} p=1drop_prob)。

不过,我们希望测试代码执行效率尽可能高,即便仅增加一个概率计算也不是我们希望的。所以实际计算时,会在训练阶段给神经元乘上一个缩放因子 1 p \frac{1}{p} p1。这样,训练输出的均值仍为 μ \mu μ,测试则不进行Dropout也不再乘上 p p p而是原样输出。

2. 功能

优势:
Dropout能够提高网络的泛化能力,防止过拟合。解释如下:
(1) 训练阶段每个神经元是相互独立的,仅drop_prob相同,即使是同一批次不同样本失活的神经元也是不同的。所以原文作者将Dropout的操作视为多种模型结构下结果的集成,由于集成方法能够避免过拟合,因此Dropout也能达到同样的效果。
(2) 减少神经元之间的协同性。有些神经元可能会建立与其它节点的固定联系,通过Dropout强迫神经元和随机挑选出来的其它神经元共同工作,减弱了神经元节点间的联合适应性,增强了泛化能力。
劣势:
(1) Dropout减缓了收敛的速度。训练时需要通过伯努利分布生成是否drop每一个神经元的情况,额外的乘法和缩放运算也会增加时间。
(2) Dropout一般用于全连接层,卷积层一般使用BatchNorm来防止过拟合。Dropout与BatchNorm不易兼容,Dropout导致训练过程中每一层输出的方差发生偏移,使得BatchNorm层统计的方差不准确,影响BatchNorm的正常使用。

3. 实现

import torch.nn as nn
import torch


class dropout(nn.Module):
    def __init__(self, drop_prob):
        super(dropout, self).__init__()
        assert 0 <= drop_prob <= 1, 'drop_prob should be [0, 1]'
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.training:
            keep_prob = 1 - self.drop_prob
            mask = keep_prob + torch.rand(x.shape)
            mask.floor_()
            return x.div(keep_prob) * mask
        else:
            return x


if __name__ == '__main__':
    x = torch.randn((8, 768))  # [batch_size, feat_dim],dropout常在全连接层之后,所以我们以一维数据为例
    drop = dropout(0.1)
    my_o = drop(x)

二、DropPath

1. 概念

DropPath 在训练阶段将深度学习网络中的多分支结构随机删除,效果如下图:

上图是ViT中的一个模块,多分支体现在ResNet结构的引入。可以看出,DropPath在多分支中起作用对位置有明确的要求,需要放在分支合并之前。此外,DropPath也需要对训练输出进行缩放(乘 1 1 − drop_prob \frac{1}{1-\text{drop\_prob}} 1drop_prob1)以确保测试输出结果的有效性和计算的高效性,这样在测试阶段就不会进行DropPath。

事实上,DropPath功能的实现是按照drop_prob概率将该分支的当前输出全部置0。具体来说,对于某个含有DropPath的分支,该分支输出的一个批次的每个样本都独立的按照drop_prob概率被完全置0或完整保留。

2. 功能

一般可以作为正则化手段加入网络防止过拟合,但会增加网络训练的难度。如果设置的drop_prob过高,模型甚至有可能不收敛。

3. 实现

import torch
import torch.nn as nn


class DropPath(nn.Module):
    """
    随机丢弃该分支上的每个样本
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob

        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # (batch_size, 1, 1, 1)维数与输入保持一致,仅需要batch_size个值
        mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # 二值化,向下取整用于确定保存哪些样本

        output = x.div(keep_prob) * mask
        return output


if __name__ == "__main__":
    x = torch.randn((8, 197, 768))  # [batch_size, num_token, token_dim]
    drop_path = DropPath(drop_prob=0.5)
    my_o = drop_path(x)

致谢:

本博客仅做记录使用,无任何商业用途,参考内容如下:
【个人理解向】Dropout和Droppath原理及源码讲解
nn.Dropout、DropPath的理解与pytorch代码
Drop系列正则化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562384.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IMUGNSS的误差状态卡尔曼滤波器(ESKF)---更新过程

IMU&GNSS的误差状态卡尔曼滤波器&#xff08;ESKF&#xff09;---更新过程 ESKF的更新过程 ESKF的更新过程 前面介绍的是ESKF的运动过程&#xff0c;现在考虑更新过程。假设一个抽象的传感器能够对状态变量产生观测&#xff0c;其观测方程为抽象的h,那么可以写为 其中z为…

创新指南|节日期间提高销量的 10 个最佳技巧

许多网上购物者在感恩节前开始假日购物。假期是在线企业销售产品和增加销售额的最佳时机。根据万事达卡的数据&#xff0c;去年在线假日销售额增长了 10.6%&#xff0c;而店内销售额增长了 6.8%。此外&#xff0c;2023年美国消费者平均计划在假日旺季花费约1,530美元。在线企业…

存储过程的查询

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 在实际使用中&#xff0c;经常会需要查询数据库中已有的存储过程或者某一个存储过程的内容&#xff0c; 下面就介绍-下如何查询存储过程。 这需要使用到数据字典 user_sou…

vscode 配置verilog环境

一、常用的设置 1、语言设置 安装如下插件&#xff0c;然后在config 2、编码格式设置 解决中文注释乱码问题。vivado 默认是这个格式&#xff0c;这里也设置一样。 ctrl shift p 打开设置项 3、插件信任区设 打开一个verilog 文件&#xff0c;显示是纯本文&#xff0c;没…

B树和B+树试题解析

一、单项选择题 01&#xff0e;下图所示是一棵&#xff08;A ). A.4阶B树 B.3阶B树 C.4阶B树 D.无法确定 02.下列关于m阶B树的说法中&#xff0c;错误的是( C ). A.根结点至多有m棵子树 B.所有叶结点都在同一层次上 C.非叶结点至…

算法入门——二分查找

目录 1、二分模板 2、习题 1.704.二分查找 2.35.搜索插入位置 3.744. 寻找比目标字母大的最小字母 4.69. x 的平方根 5.1351. 统计有序矩阵中的负数 6.74. 搜索二维矩阵 7.34. 在排序数组中查找元素的第一个和最后一个位置 8.33. 搜索旋转排序数组 9.153. 寻找旋转排…

【GoWeb框架初探————XORM篇】

1. XORM xorm 是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 1.1 特性 支持 Struct 和数据库表之间的灵活映射&#xff0c;并支持自动同步事务支持同时支持原始SQL语句和ORM操作的混合执行使用连写来简化调用支持使用ID, In, Where, Limit, Join, Havi…

java学习笔记2

3 选择结构 3.1 if选择结构 3.1.1 基本if结构 语法if(条件){// 代码块 }执行流程 当if条件为真,执行代码块,否则不执行代码块。 代码 public class Demo1 {public static void main(String[] args) {// 需求: 张浩的考试成绩>90分,奖励一部Iphone6sScanner sc = new S…

mapreduce中的ReduceTask工作机制(Hadoop)

ReduceTask 是 Hadoop 中的一个重要组件&#xff0c;负责对 MapTask 的输出进行合并、排序和归并&#xff0c;最终生成最终的输出结果。 ReduceTask 的工作机制 1. 分组&#xff08;Shuffle&#xff09;阶段&#xff1a; 在分组阶段&#xff0c;ReduceTask 会从多个 Mapper …

第二届 Oceanbase 开发者大会 实录

第二届 Oceanbase 开发者大会 实录 今天很有幸参加了Oceanbase 开发者大会&#xff0c;我是真的我一开始还不知道什么是Oceanbase &#xff0c;直到我开了会才知道。看来真的需要多参加一些这样活动。 会议议程 我们科普一下什么是Oceanbase OceanBase 是阿里巴巴集团推出…

FastChat启动与部署通义千问大模型

FastChat简介 FastChat is an open platform for training, serving, and evaluating large language model based chatbots. FastChat powers Chatbot Arena, serving over 10 million chat requests for 70 LLMs.Chatbot Arena has collected over 500K human votes from sid…

Llama 3 实测效果炸裂,一秒写数百字(附镜像站)

这几天大火的llama 3刚刚在https://askmanyai.cn上线了&#xff01; 玩了一会儿&#xff0c;这个生成速度是真的亚麻呆住。文案写作和代码生成直接爽到起飞&#xff0c;以往gpt要写一两分钟的千字文&#xff0c;llama 3几秒钟就写完了。而且效果甚至感觉更好&#xff1f; 效果惊…

日期相关的题目

日期相关的题目 1. 计算日期到天数转换2. 日期累加3. 打印日期4. 日期差值 1. 计算日期到天数转换 输出示例: 思路&#xff1a;计算前n-1个月的天数在加上这个月的天数。 #include <iostream> using namespace std;int main() {int year, month, day;cin >> yea…

数据结构练习-数据结构概述

----------------------------------------------------------------------------------------------------------------------------- 1. 在数据结构中&#xff0c;从逻辑上可以把数据结构分成( )。 A. 动态结构和静态结构 B. 紧凑结构和非紧凑结构 C. 线性结…

Spring AI Summary

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl Spring AI is a project that aims to streamline the development of AI applications by providing abstractions and reusable components that can be easily integrate…

梯度消失/梯度爆炸

梯度消失/梯度爆炸&#xff08;Vanishing / Exploding gradients&#xff09; 梯度消失或梯度爆炸&#xff1a;训练神经网络的时候&#xff0c;导数或坡度有时会变得非常大&#xff0c;或者非常小&#xff0c;甚至于以指数方式变小&#xff0c;这加大了训练的难度。 g ( z ) …

Java学习Go(入门)

下载Go 《官网下载golang》 直接点Download&#xff0c;然后根据你自己的操作系统进行下载&#xff0c;我这里以win10为例 安装go 默认安装到C:\Program Files\Go&#xff0c;这里我们可以选择安装到其他盘&#xff0c;也可以选择默认安装。初学者建议直接一路next。 安装完…

Java发送邮件 启用SSL

使用的maven依赖: <dependency><groupId>com.sun.mail</groupId><artifactId>javax.mail</artifactId><version>1.4.7</version> </dependency> 配置文件mail.properties如下: # 邮箱配置 email.username=your-email@exa…

(助力国赛)美赛O奖数学建模可视化!!!含代码2(箱型图、旭日图、直方图、三元图、平行坐标图、密度图、局部放大图)

众所周知&#xff0c;数学建模的过程中&#xff0c;将复杂的数据和模型结果通过可视化图形呈现出来&#xff0c;不仅能够帮助我们更深入地理解问题&#xff0c;还能够有效地向评委展示我们的研究成果。   今天&#xff0c;承接《可视化代码1》&#xff0c;作者将与大家分享《…

【软考---系统架构设计师】软件架构

目录 1 一、软件架构的概念 二、软件架构风格 &#xff08;1&#xff09;数据流风格​​​​​​​ &#xff08;2&#xff09;调用/返回风格 &#xff08;3&#xff09;独立构件风格 &#xff08;4&#xff09;虚拟机风格 &#xff08;5&#xff09;仓库风格 三、架构…