droppath

DropPath 是一种用于正则化深度学习模型的技术,它在训练过程中随机丢弃路径(或者说随机让某些部分的输出变为零),从而增强模型的鲁棒性和泛化能力。

代码解释:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
#drop_path(输入,将drop_prob初始化为0., 判断是否为训练模式)
    if drop_prob == 0. or not training:
        return x
#如果drop_prob等于0或者不是训练模式直接将输入输出
    keep_prob = 1 - drop_prob
#保留的概率
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) 
# 形状:(batch_size, 1, 1, ...)
# x.shape[0]获取xshape的第一维也就是batch_size
# (1,) * (x.ndim - 1) 将shape用1填充和x的形状一样
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
# torch.rand(shape, dtype=x.dtype, device=x.device)生成随机数(生成均值为0,标准差为1的正态# 分布随机数)形状和shape一致的也就是和x一致,数据类型,设备都和x一致
# 将随机数和keep_prob相加得到随机数(范围[keep_prob,1+keep_prob])
    random_tensor.floor_() 
# 二值化,生成 0 或 1 的 mask
# 也就是将随机数向下取整
    output = x.div(keep_prob) * random_tensor
#x.div(keep_prob)将输入张量x的所有值除以keep_prob,目的是 放大保留下来的部分

#* random_tensor根据0 或 1 的 mask决定哪些路径会被保留(1)或丢弃(0)
    return output

为什么要放大保留下来的部分:

  • 丢弃路径会导致部分值被置为零,模型整体输出的总期望值会下降。
  • 为了补偿这种下降,需要对保留下来的部分放大,使得丢弃路径后的总期望值和丢弃前一致。

因为只是补偿所以并不一定等与原期望

数学解释:

假设输入张量是 x=\begin{bmatrix} x_{1,}&x_{2,} & ... &, x_{n} \end{bmatrix},其中每个元素 xi表示特征。

期望:E=\frac{1}{n}\sum_{1}^{n}x_{i}

丢弃之后:E=\frac{1}{n}\sum_{1}^{n}{keepprob}\cdot x_{i}

放大之后:E=\frac{1}{n}\sum_{1}^{n}\frac{​{keepprob}\cdot x_{i}}{keepprob}=\frac{1}{n}\sum_{1}^{n}x_{i}

实例:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 形状:(batch_size, 1, 1, ...)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 二值化,生成 0 或 1 的 mask
    print(f'mask: {random_tensor}')
    output = x.div(keep_prob) * random_tensor
    return output

# 定义简单模型
class SimpleModel(nn.Module):
    def __init__(self, drop_prob):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # 简单的线性层
        self.drop_path = DropPath(drop_prob)  # 使用 DropPath
        self.activation = nn.ReLU()  # ReLU 激活

    def forward(self, x):
        print("输入数据:")
        print(x)

        x = self.linear(x)  # 线性层
        print("线性层输出:")
        print(x)

        x = self.activation(x)  # ReLU 激活
        print("激活后输出:")
        print(x)

        x = self.drop_path(x)  # DropPath
        print("DropPath 后输出:")
        print(x)

        return x

# 创建模型
model = SimpleModel(drop_prob=0.5)
model.train()  # 设置为训练模式以启用 DropPath

# 输入数据
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                           [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)

# 运行模型
output = model(input_data)

输出: 简单理解就是根据mask的1,0值对每个样本进行保留或置零

输入数据:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
线性层输出:
tensor([[ 1.2836, -1.4602,  2.2660, -1.7250],
        [ 1.3035, -4.1391,  4.5453, -2.5738]], grad_fn=<AddmmBackward0>)
激活后输出:
tensor([[1.2836, 0.0000, 2.2660, 0.0000],
        [1.3035, 0.0000, 4.5453, 0.0000]], grad_fn=<ReluBackward0>)
mask: tensor([[1.],
        [0.]])
DropPath 后输出:
tensor([[2.5672, 0.0000, 4.5321, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/927835.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

KAN-Transfomer——基于新型神经网络KAN的时间序列预测

1.数据集介绍 ETT(电变压器温度)&#xff1a;由两个小时级数据集&#xff08;ETTh&#xff09;和两个 15 分钟级数据集&#xff08;ETTm&#xff09;组成。它们中的每一个都包含 2016 年 7 月至 2018 年 7 月的七种石油和电力变压器的负载特征。 traffic(交通) &#xff1a;描…

03-12、SpringCloud Alibaba第十二章,升级篇,服务注册与配置中心Nacos

SpringCloud Alibaba第十二章&#xff0c;升级篇&#xff0c;服务注册与配置中心Nacos 一、为什么SpringCloud Alibaba 1、为什么 有了spring cloud这个微服务的框架&#xff0c;为什么又要使用spring cloud alibaba这个框架了&#xff1f;最重要的原因在于spring cloud中的…

算法之旅:LeetCode 拓扑排序由简入繁完全攻略

前言 欢迎来到我的算法探索博客&#xff0c;在这里&#xff0c;我将通过解析精选的LeetCode题目&#xff0c;与您分享深刻的解题思路、多元化的解决方案以及宝贵的实战经验&#xff0c;旨在帮助每一位读者提升编程技能&#xff0c;领略算法之美。 &#x1f449;更多高频有趣Lee…

MATLAB 离散点构建凸包,计算面积周长(88)

MATLAB 离散点构建凸包,计算面积周长(88) 一、算法介绍二、算法实现1.代码2.总结这是缘,亦是命中最美的相见!!! 一、算法介绍 给定一堆离散点云,构建二维凸包,并计算凸包的面积和周长。 凸包是由顺序顶点构成的,因此凸包也可以当作多边形,则例的面积和周长计算方法…

Matlab Simulink HDL Coder开发流程(一)— 创建HDL兼容的Simulink模型

创建HDL兼容的Simulink模型 一、使用Balnk DUT模板二、从HDL Coder库中选择模块三、为DUT开发算法/功能四、为设计创建Testbench五、仿真验证设计功能六、Simulink模型生成HDL代码 这个例子说明了如何创建一个用于生成HDL代码的Simulink模型。要创建兼容HDL代码生成的MATLAB算法…

【智商检测——DP】

题目 代码 #include <bits/stdc.h> using namespace std; const int N 1e510, M 110; int f[N][M]; int main() {int n, k;cin >> n >> k;for(int i 1; i < n; i){int x;cin >> x;f[i][0] __gcd(f[i-1][0], x);for(int j 1; j < min(i, k)…

神经网络入门实战:(九)分类问题 → 神经网络模型搭建模版和训练四步曲

(一) 神经网络模型搭建官方文档 每一层基本都有权重和偏置&#xff0c;可以仔细看官方文档。 pytorch 官网的库&#xff1a;torch.nn — PyTorch 2.5 documentation Containers库&#xff1a;用来搭建神经网络框架&#xff08;包含所有的神经网络的框架&#xff09;&#xff1b…

不同云计算网络安全等级

导读云计算的本质是服务&#xff0c;如果不能将计算资源规模化/大范围的进行共享&#xff0c;如果不能真正以服务的形式提供&#xff0c;就根本算不上云计算。 等级保护定级流程 定级是开展网络安全等级保护工作的 “基本出发点”&#xff0c;虚拟化技术使得传统的网络边界变…

langchain实现基于sql的问答

1. 数据准备 import requestsurl "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"response requests.get(url)if response.status_code 200:# Open a local file in binary write modewith open("Chinook.db", "wb&qu…

flink学习(14)—— 双流join

概述 Join:内连接 CoGroup&#xff1a;内连接&#xff0c;左连接&#xff0c;右连接 Interval Join&#xff1a;点对面 Join 1、Join 将有相同 Key 并且位于同一窗口中的两条流的元素进行关联。 2、Join 可以支持处理时间&#xff08;processing time&#xff09;和事件时…

深入学习指针(5)!!!!!!!!!!!!!!!

文章目录 1.回调函数是什么&#xff1f;2.qsort使用举例2.1使用qsort函数排序整形数据2.2使用sqort排序结构数据 3.qsort函数的模拟实现 1.回调函数是什么&#xff1f; 回调函数就是⼀个通过函数指针调⽤的函数。 如果你把函数的指针&#xff08;地址&#xff09;作为参数传递…

CEF127 编译指南 Linux篇 - 构建CEF Client(七)

1. 引言 在完成 CEF127 的编译工作后&#xff0c;我们需要了解如何正确运行编译后的程序。本文将详细介绍如何使用 CMake 构建示例程序&#xff0c;并成功运行 CEF 客户端。通过本文的指导&#xff0c;您将能够在 Linux 环境下顺利运行 CEF 应用程序。 2. 准备工作 2.1 确认…

位图的学习

一&#xff0c;位图介绍 位图&#xff08;Bitmap&#xff09;是一种用于存储图像的方式&#xff0c;它通过二维矩阵&#xff08;由像素组成&#xff09;来表示图像的每一个细节。每个像素通常对应一个特定的颜色值&#xff0c;位图的每个“位”就代表了图像的一个像素。 位图…

电脑与优傲协作机器人(实体)的TCP通讯(操作记录)

目录 一、UR通信端口 二、电脑&#xff08;客户端&#xff09;连接协作机器人&#xff08;服务端&#xff09; 1.设置网络方法 2.检查设置 3.示教器切换远程控制&#xff08;注&#xff09; 4.客户端与协作机器人建立连接 5.连接测试 三、电脑&#xff08;服务端&#…

后端 Java发送邮件 JavaMail 模版 20241128测试可用

配置授权码 依赖 <dependency><groupId>javax.mail</groupId><artifactId>javax.mail-api</artifactId><version>1.5.5</version> </dependency> <dependency><groupId>com.sun.mail</groupId><artifa…

12.2 正则表达式

object test04 {def main(args: Array[String]): Unit {//1.定义规则。写正则表达式val reg "\\d".r // \\d表示找数字//2.在目标字符串中&#xff0c;去按照这个规则去找符合的子字符串val result reg.findFirstIn("我是who&#xff0c;我的电话是&#xff…

MySQL:DDL数据定义语言

DDL(Data Definition Language)&#xff0c;数据定义语言 对数据库的常用操作 查看所有数据库 语法&#xff1a;show databases; 创建数据库 dbname&#xff1a;用户自己定义的数据库名称。 语法&#xff1a;create database [if not exists] dbname [charsetutf8]; 切换…

2024信创数据库TOP30之华为Gauss DB

近日&#xff0c;由DBC联合CIW/CIS共同发布的“2024信创数据库TOP30”榜单正式揭晓&#xff0c;汇聚了国内顶尖的数据库企业及其产品&#xff0c;成为展示中国信创领域技术实力与发展潜力的重要平台。在这份榜单中&#xff0c;华为的GaussDB凭借其卓越的技术实力、广泛的行业应…

HTML+CSS+JS制作圣诞祝福网页教程(附源码)

简介 在这个教程中&#xff0c;我们将学习如何使用HTML、CSS和JavaScript来创建一个充满节日气氛的圣诞祝福网页。这个网页将包括一个动态的圣诞树、飘落的雪花和闪烁的装饰物&#xff0c;以及一个显示“圣诞快乐&#xff01;”的消息。 准备工作 在开始之前&#xff0c;请确…

华为仓颉编程环境搭建

1、仓颉介绍 摘自华为官方&#xff1a;仓颉编程语言作为一款面向全场景应用开发的现代编程语言&#xff0c;通过现代语言特性的集成、全方位的编译优化和运行时实现、以及开箱即用的 IDE 工具链支持&#xff0c;为开发者打造友好开发体验和卓越程序性能。 其具体特性表现为&am…