【机器学习】037_暂退法

一、实现原理

具有输入噪音的训练,等价于Tikhonov正则化

核心方法:在前向传播的过程中,计算每一内部层的同时注入噪声

· 从作用上来看,表面上来说是在训练过程中丢弃一些神经元

· 假设x是某一层神经网络层的输出,是下一层的输入,我们希望对x加入一些噪音,使得:

E[x^`]=x

  ※x`的期望为x,也就是说平均上来说输出值还是x

· 暂退法对每个元素进行了如下扰动:

        有p的概率下取值:x^`_i=0

        其它情况(1-p概率):x^`_i = \frac{x_i}{1-p}

实践中使用暂退法:

· 通常将暂退法作用在全连接隐藏层的输出上

如图所示,在第一个隐藏层的输出上,有些神经元有p的概率使输出值置零。

非置零的输出值,即有1-p的概率被施加了一个较小的扰动值使其略微增大。

※暂退法只在训练中使用,dropout是正则项,在推理过程中不会使用,这样也会保证输出值确定

※每次执行暂退法的时候,实际上是每次随机采样了一些子神经网络

总结:

①暂退法将一些输出项随机置零来控制模型的复杂度

②暂退法的作用效果和正则化等价

③常应用在多层感知机的隐藏层输出上

④丢弃概率p是控制模型复杂度的超参数

二、代码实现

从零实现代码:

import torch
from torch import nn
from d2l import torch as d2l

def dropout_layer(X, dropout):
    # assert用于选择dropout符合范围的情况,不符合则报错
    assert 0 <= dropout <= 1, "不符合范围!"
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    # 在这一步操作中,首先定义一个和X张量形状相同但元素值均为随机数的张量
    # 将这个张量里每个元素与dropout比较,如果大于就置为True,小于等于就置为False
    # 再调用float将True和False转化为1和0
    # 这样,mask就是一个仅含1与0的张量了
    # 最后将mask里的每个元素与X里的每个元素做数乘
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

# 生成X来测试暂退法
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))

# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
# 定义模型
dropout1, dropout2 = 0.2, 0.5
# is_training用来表示当前是在测试还是在训练
class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
            # 输出不需要dropout作用
        out = self.lin3(H2)
        return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

简洁实现代码:

import torch
from torch import nn
from d2l import torch as d2l

# 定义概率参数
dropout1, dropout2 = 0.2, 0.5

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

# 训练、测试模型
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/169773.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux进程通信——IPC、管道、FIFO的引入

进程间的通信——IPC 进程间通信 (IPC&#xff0c;InterProcess Communication) 是指在不同进程之间传播或交换信息。 IPC的方式通常有管道 (包括无名管道和命名管道) 、消息队列、信号量、共享存储、Socket、Streams等。其中 Socket和Streams支持不同主机上的两个进程IPC。 …

1230天,百度再见!!!

从2020年7月8日至2023年11月20日&#xff0c;在百度的工作到达了终点&#xff0c;完成了从学生向职场人的蜕变&#xff0c;是时候说再见了&#xff01; 一、成长收获 在这1230天里收获颇丰&#xff0c;下面与各位分享一下。 从技术至上到业务赋能的思想转变 相信很多人都存在“…

初始环境配置

目录 一、JDK1、简介2、配置步骤 二、Redis1、简介2、配置步骤 三、MySQL1、简介2、配置步骤 四、Git1、简介2、配置步骤 五、NodeJS1、简介2、配置步骤 六、Maven1、简介2、配置步骤 七、Tomcat1、简介2、配置步骤 一、JDK 1、简介 JDK 是 Oracle 提供的 Java 开发工具包&…

发币成功,记录一下~

N年前就听说了这样一种说法——“一个熟练的区块链工程师&#xff0c;10分钟就可以发出一个新的币” 以前仅仅是有这么一个认识&#xff0c;但当时并不特别关注这个领域。 最近系统性学习中&#xff0c;今天尝试发币成功啦&#xff0c;记录一下&#xff5e; 发在 Sepolia Tes…

EI论文程序:Adaboost-BP神经网络的回归预测算法,可作为深度学习对比预测模型,丰富实验内容,自带数据集,直接运行!

适用平台&#xff1a;Matlab 2021及以上 本程序参考中文EI期刊《基于Adaboost的BP神经网络改进算法在短期风速预测中的应用》&#xff0c;程序注释清晰&#xff0c;干货满满&#xff0c;下面对文章和程序做简要介绍。 为了提高短期风速预测的准确性&#xff0c;论文提出了使用…

【前端学java】java 中的数组(9)

往期回顾&#xff1a; 【前端学java】JAVA开发的依赖安装与环境配置 &#xff08;0&#xff09;【前端学 java】java的基础语法&#xff08;1&#xff09;【前端学java】JAVA中的packge与import&#xff08;2&#xff09;【前端学java】面向对象编程基础-类的使用 &#xff08…

猫12分类:使用多线程爬取图片的Python程序

本文目标 对于猫12目标检测部分的数据集&#xff0c;采用网络爬虫来制作数据集。 在网络爬虫中&#xff0c;经常需要下载大量的图片。为了提高下载效率&#xff0c;可以使用多线程来并发地下载图片。本文将介绍如何使用Python编写一个多线程爬虫程序&#xff0c;用于爬取图片…

代码随想录 Day50 单调栈 LeetCodeT503 下一个最大元素II T42接雨水

前言 前面我们说到了单调栈的第一题,下一个最大元素I,其实今天的两道题都是对他的变种,知道第一个单调栈的思想能够想清楚,其实这道题是很简单的 考虑好三个状态,大于等于小于,其实对于前面这些题目只要细心的小伙伴就会发现其实小于和等于的处理是一样的都是直接入栈,只有大于…

记录一次较为完整的Jenkins发布流程

文章目录 1. Jenkins安装1.1 Jenkins Docker安装1.2 Jenkins apt-get install安装 2. 关联github/gitee服务与webhook2.1 配置ssh2.2 Jenkins关联2.3 WebHook 3. 前后端关联发布 1. Jenkins安装 1.1 Jenkins Docker安装 Docker很好&#xff0c;但是我没有玩明白如何使用Docke…

【并发编程】Synchronized原理详解

&#x1f4eb;作者简介&#xff1a;小明java问道之路&#xff0c;2022年度博客之星全国TOP3&#xff0c;专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化&#xff0c;文章内容兼具广度、深度、大厂技术方案&#xff0c;对待技术喜欢推理加验证&#xff0c;就职于…

广州华锐互动VRAR | VR课件内容编辑器解决院校实践教学难题

VR课件内容编辑器由VR制作公司广州华锐互动开发&#xff0c;是一款专为虚拟现实教育领域设计的应用&#xff0c;它能够将传统的教学内容转化为沉浸式的三维体验。通过这款软件&#xff0c;教师可以轻松创建和编辑各种虚拟场景、模型和动画&#xff0c;以更生动、直观的方式展示…

.NET6使用MiniExcel根据数据源横向导出头部标题及数据

.NET6MiniExcel根据数据源横向导出头部标题 MiniExcel简单、高效避免OOM的.NET处理Excel查、写、填充数据工具。 特点: 低内存耗用&#xff0c;避免OOM、频繁 Full GC 情况 支持即时操作每行数据 兼具搭配 LINQ 延迟查询特性&#xff0c;能办到低消耗、快速分页等复杂查询 轻量…

CommonModule.dll动态链接库(DLL)文件丢失的处理方法

方法一、手动下载修复 (1)从网站下载commonmodule.dll文件到您的电脑上。 (2)将commonmodule.dll文件复制到" X:\Windows\system32 " (X代表您系统所在目录盘符&#xff0c;如&#xff1a;C:\Windows\system32)目录下。 (3)在开始菜单中找到"运行(R)" 或…

数据结构--字符串的模式匹配

案例导入概念 朴素&#xff08;暴力&#xff09;模式匹配算法 定位操作&#xff1a; 计算时间复杂度 总结

解决 Python requests 库中 SSL 错误转换为 Timeouts 问题

解决 Python requests 库中 SSL 错误转换为 Timeouts 问题&#xff1a;理解和处理 SSL 错误的关键 在使用Python的requests库进行HTTPS请求时&#xff0c;可能会遇到SSL错误&#xff0c;这些错误包括但不限于证书不匹配、SSL层出现问题等。如果在requests库中设置verifyFalse&…

ES6有何新特性?(下篇)

目录 函数参数的默认值设置 rest参数 扩展运算符 Symbol 迭代器 生成器 Promise Class 数值扩展 对象方法扩展 模块化 大家好呀&#xff01;今天这篇文章继续为大家介绍ES6的新特性&#xff0c;上上上篇文章介绍了一部分&#xff0c;这篇文章会将剩下的部分新增的特…

ElasticSearch在Windows上的下载与安装

Elasticsearch是一个开源的分布式搜索和分析引擎&#xff0c;它可以帮助我们快速地搜索、分析和处理大量数据。Elasticsearch能够快速地处理结构化和非结构化数据&#xff0c;支持全文检索、地理位置搜索、自动补全、聚合分析等功能&#xff0c;能够承载各种类型的应用&#xf…

用平板当电脑副屏(spacedesk)双端分享

文章目录 1.准备工作2.操作流程1. 打开spacedesk点击2. 勾选USB Cable Android3. 用数据线连接移动端和pc端&#xff0c;选择仅充电4. 打开安装好的spacedesk 记得在win系统中设置扩展显示器&#xff1a; 1.准备工作 下载软件spacedesk Driver Console pc端&#xff1a; 移动…

uniapp小程序定位;解决调试可以,发布不行的问题

遇见这个问题&#xff1b;一般情况就两种 1、域名配置问题&#xff1b; 2、隐私协议问题 当然&#xff0c;如果你的微信小程序定位接口没开启&#xff1b;定位也会有问题&#xff1b; 第一种&#xff0c;小程序一般是腾讯地图&#xff1b;所以一般都会用https://apis.map.qq.co…

Android studio run 手机或者模拟器安装失败,但是生成了debug.apk

错误信息如下&#xff1a;Error Installation did not succeed. The application could not be installed&#xff1a;List of apks 出现中文乱码&#xff1b; 我首先尝试了打包&#xff0c;能正常安装&#xff0c;再次尝试了debug的安装包&#xff0c;也正常安装&#xff1…