抑制过拟合——Dropout原理

抑制过拟合——Dropout原理

  • Dropout的工作原理
  • 实验观察

  在机器学习领域,尤其是当我们处理复杂的模型和有限的训练样本时,一个常见的问题是过拟合。简而言之,过拟合发生在模型对训练数据学得太好,以至于它捕捉到了数据中的噪声和误差,而不仅仅是底层模式。具体来说,这在神经网络训练中尤为常见,表现为在训练数据上表现优异(例如损失函数值很小,预测准确率高)而在未见过的数据(测试集)上表现不佳。

  过拟合不仅是机器学习新手容易遇到的问题,即使是经验丰富的从业者也会面临这一挑战。一个典型的解决方案是采用模型集成技术,这涉及训练多个模型并将它们的预测结合起来。但这种方法的缺点是显而易见的:它既耗时又昂贵,不仅在训练阶段,而且在模型评估和部署时也是如此。

  在这种背景下,Dropout 作为一种有效的正则化技术,可以显著减轻过拟合问题。它的基本原理是在每次训练迭代中随机“丢弃”(即暂时移除)网络中的一部分神经元。这种方法不仅简单,而且被证明在许多情况下都非常有效。

Dropout的工作原理

  在 PyTorch 中,Dropout 层的使用相当直观。通常,它被添加到神经网络的各个层之间,如下所示:

torch.nn.Dropout(p=0.5, inplace=False)

  p:这是一个关键参数,代表着每个神经元被丢弃的概率。

  在实践中,这意味着对于网络中的每个神经元,它在每次训练迭代中都有 1 − p 1-p 1p 的概率被保留, p p p 的概率被丢弃。值得注意的是,这种随机性确保了每个mini-batch都在对不完全相同的网络进行训练,从而减少过拟合的风险。

  在训练期间,对于每个训练样本,网络中的每个神经元都有概率 1 − p 1-p 1p 被保留,概率 p p p 被丢弃。如果神经元被保留,则其输出乘以 1 1 − p \frac{1}{1-p} 1p1​(这样做是为了保持该层输出的总期望值不变)。设 r j r_j rj​ 为一个随机变量,它对应于第 j j j 个神经元,且服从伯努利分布(即 r j = 1 r_j = 1 rj=1 的概率为 1 − p 1-p 1p r j = 0 r_j = 0 rj=0 的概率为 p p p)。那么在训练时,神经元的输出 y j y_j yj变为 r j × y j / ( 1 − p ) r_j \times y_j / (1-p) rj×yj/(1p)

为什么需要保持期望不变? 举个简单的例子,假设某层有两个神经元,它们的输出在没有dropout时都是1。在应用了50%的dropout后,期望只有一个神经元被激活,输出为1,另一个被丢弃,输出为0。这样,这层的平均输出变成了0.5。为了保持输出的总期望值不变,激活的神经元的输出应该乘以2,即 1 1 − p \frac{1}{1-p} 1p1​,这样平均输出才能保持为1,与没有应用dropout时相同。这样的处理有助于保持整个网络的稳定性和一致性。

  在模型预测(或测试)阶段,所有的神经元都保持激活(即不进行dropout)。因为在训练阶段,神经元的输出已经被放大了 1 1 − p \frac{1}{1-p} 1p1 倍,所以在预测时不需要进行任何调整,直接使用网络进行前向传播即可。

在这里插入图片描述

实验观察

  为了更深入地理解 Dropout 的影响,我们可以通过一个实验来观察不同的 Dropout 设置对训练过程的影响。比如,可以比较 Dropout = 0.1Dropout = 0 在训练过程中的表现差异,相关代码实现如下:

import torch
from tensorboardX import SummaryWriter
from torch import optim, nn
import time


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linears = nn.Sequential(
            nn.Linear(2, 20),

            nn.Linear(20, 20),
            nn.Dropout(0.1),

            nn.Linear(20, 20),

            nn.Linear(20, 20),

            nn.Linear(20, 1),
        )

    def forward(self, x):
        _ = self.linears(x)
        return _

lr = 0.01
iteration = 1000


x1 = torch.arange(-10, 10).float()
x2 = torch.arange(0, 20).float()
x = torch.cat((x1.unsqueeze(1), x2.unsqueeze(1)), dim=1)
y = 2*x1 - x2**2 + 1

model = Model()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
loss_function = torch.nn.MSELoss()

start_time = time.time()
writer = SummaryWriter(comment='_随机失活')

for iter in range(iteration):
    y_pred = model(x)
    loss = loss_function(y, y_pred.squeeze())
    loss.backward()

    for name, layer in model.named_parameters():
        writer.add_histogram(name + '_grad', layer.grad, iter)
        writer.add_histogram(name + '_data', layer, iter)
    writer.add_scalar('loss', loss, iter)

    optimizer.step()
    optimizer.zero_grad()

    if iter % 50 == 0:
        print("iter: ", iter)

print("Time: ", time.time() - start_time)

这里我们使用 TensorBoardX 进行结果的可视化展示。

  通过观察模型训练1000轮后的线性层梯度分布,可以发现,应用 Dropout 后的模型梯度通常会更加分散和多样化。这种梯度的多样性有助于防止模型过于依赖训练数据中的特定模式,从而减轻过拟合。

在这里插入图片描述

  同样值得注意的是,模型的损失曲线也会受到影响。加入 Dropout 通常会使损失曲线出现更多的波动(例如,图中的蓝色曲线),这反映了模型在学习过程中的不稳定性。然而,这种不稳定性通常是可接受的,因为它反映了模型正在学习更多的泛化模式而不是简单地记住训练数据。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/205405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智能优化算法应用:基于旗鱼算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于旗鱼算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于旗鱼算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.旗鱼算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

收藏!7个小众宝藏的开发者学习类网站

1、simplilearn 地址:https://www.simplilearn.com/ simplilearn是全球排名第一的在线学习网站,它的课程由世界知名大学、顶级企业和领先的行业机构通过实时在线课程设计和提供,其中包括顶级行业从业者、广受欢迎的培训师和全球领导者。 2、…

【代码随想录】算法训练计划37

贪心 1、738. 单调递增的数字 题目&#xff1a; 输入: n 10 输出: 9 思路&#xff1a; func monotoneIncreasingDigits(n int) int {// 贪心&#xff0c;利用字符数组s : strconv.Itoa(n)ss : []byte(s)leng : len(ss)if leng < 1 {return n}for i:leng-1; i>0; i-- …

西南科技大学数字电子技术实验一(数字信号基本参数与逻辑门电路功能测试及FPGA 实现)FPGA部分

一、 实验目的 1、掌握基于 Verilog 语言的 diamond 工具设计全流程。 2、熟悉、应用 Verilog HDL 描述数字电路。 3、掌握 Verilog HDL 的组合和时序逻辑电路的设计方法。 4、掌握“小脚丫”开发板的使用方法。 二、 实验原理 与门逻辑表达式:Y=AB 原理仿真图: 2 输入…

springboot+jsp+java人才招聘网站4f21r

本基于springboot的人才招聘网站主要满足3种类型用户的需求&#xff0c;这3种类型用户分别为求职者、企业和管理员&#xff0c;他们分别实现的功能如下。 &#xff08;1&#xff09;求职者进入网站后可查看职位信息、企业信息以及职位新闻等&#xff0c;注册登录后可实现申请职…

C#测试开源运行耗时库MethodTimer.Fody

微信公众号“dotNET跨平台”的文章《一个监控C#方法运行耗时开源库》介绍了支持测量方法耗时的包MethodTimer.Fody&#xff0c;使用方便&#xff0c;还可以自定义输出信息格式。本文学习并测试MethodTimer.Fody包的使用方式。   新建控制台程序&#xff0c;通过Nuget包管理器…

java学校高校运动会报名信息管理系统springboot+jsp

课题研究方案&#xff1a; 结合用户的使用需求&#xff0c;本系统采用运用较为广泛的Java语言&#xff0c;springboot框架&#xff0c;HTML语言等关键技术&#xff0c;并在idea开发平台上设计与研发创业学院运动会管理系统。同时&#xff0c;使用MySQL数据库&#xff0c;设计实…

git-5

1.GitHub为什么会火&#xff1f; 2.GitHub都有哪些核心功能&#xff1f; 3.怎么快速淘到感兴趣的开源项目 github上面开源项目非常多&#xff0c;为了我们高效率的找到我们想要的资源 根据时间 不进行登录&#xff0c;是没有办法享受到高级搜索中的代码功能的&#xff0c;登录…

MacOS + Android Studio 通过 USB 数据线真机调试

环境&#xff1a;Apple M1 MacOS Sonoma 14.1.1 软件&#xff1a;Android Studio Giraffe | 2022.3.1 Patch 3 设备&#xff1a;小米10 Android 13 一、创建测试项目 安卓 HelloWorld 项目: 安卓 HelloWorld 项目 二、数据线连接手机 1. 手机开启开发者模式 参考&#xff1…

【Qt绘图】之绘制坦克

使用绘图事件&#xff0c;绘制坦克。 效果 效果很逼真&#xff0c;想象力&#xff0c;有没有。 示例 代码像诗一样优雅&#xff0c;有没有。 包含头文件 #include <QApplication> #include <QWidget> #include <QPainter>绘制坦克类 class TankWidge…

揭示堆叠自动编码器的强大功能 - 最新深度学习技术

简介 在不断发展的人工智能和机器学习领域&#xff0c;深度学习技术由于其处理复杂和高维数据的能力而获得了巨大的普及。在各种深度学习模型中&#xff0c;堆叠自动编码器[1]作为一种多功能且强大的工具脱颖而出&#xff0c;用于特征学习、降维和数据表示。本文探讨了堆叠式自…

【JavaEE】多线程 -- 死锁问题

目录 1. 问题引入 2.死锁问题的概念和原因 3. 解决死锁问题 1. 问题引入 在学习死锁之前, 我们先观察下面的代码能否输出正确的结果: 运行程序, 能正常输出结果: 这个代码只管上看起来, 好像是有锁冲突的, 此时的 locker 对象已经是加锁的状态, 在尝试对 locker 加锁, 不应该…

鸿蒙HarmonyOS应用开发-ColumnRow组件

1 概述 一个丰富的页面需要很多组件组成&#xff0c;那么&#xff0c;我们如何才能让这些组件有条不紊地在页面上布局呢&#xff1f;这就需要借助容器组件来实现。 容器组件是一种比较特殊的组件&#xff0c;它可以包含其他的组件&#xff0c;而且按照一定的规律布局&#xf…

【docker系列】docker实战之部署SpringBoot项目

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

js提取iconfont项目的图标

iconfont 可以让我们轻松使用字体图标&#xff0c;比如使用 iconfont 提供的 js&#xff0c;就可以愉快的码代码了。 //at.alicdn.com/t/c/font_xxxxx.js通常公司会有提供一套图标供所有系统使用&#xff0c;比如图标库里有 1000 个图标&#xff0c;但某个项目只需要使用 10 个…

南大通用 GBase 8s数据库级别权限

对于所有有权使用指定数据库的用户都必须赋予其数据库级别的用户权限。在GBase 8s 中&#xff0c;数据库级别的用户权限有三种&#xff0c;按权限从低到高排列依次为&#xff1a;CONNECT、RESOURCE、DBA。 1. CONNECT 这是级别最低的一种数据库级别用户权限。拥有该权限的用户…

Windows Terminal CMD 终端配置方案: 不只是酷炫外观

大一的时候小学期我们还是用 Windows cmd 终端写的订餐系统&#xff0c;尽管进我们所能地改了改配色&#xff0c;成品还是让人不忍直视。 当时学习遇到的大多数运行需求可以通过 IDE 解决&#xff0c;再加上 CMD 丑成这样&#xff0c;挺让人抵触的。 后来对命令行操作的学习需…

【数据清洗 | 数据规约】数据类别型数据 编码最佳实践,确定不来看看?

&#x1f935;‍♂️ 个人主页: AI_magician &#x1f4e1;主页地址&#xff1a; 作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域优质创作者。 &#x1f468;‍&#x1f4bb;景愿&#xff1a;旨在于能和更多的热爱计算机的伙伴一起成长&#xff01;&#xff01;&…

回文链表,剑指offer 27,力扣 61

目录 题目&#xff1a; 我们直接看题解吧&#xff1a; 解题方法&#xff1a; 难度分析&#xff1a; 审题目事例提示&#xff1a; 解题分析&#xff1a; 解题思路&#xff08;数组列表双指针&#xff09;&#xff1a; 代码说明补充&#xff1a; 代码实现&#xff1a; 代码实现&a…

Pix2Pix 使用指南:一副图像到另一副图像的转换

Pix2Pix Pix2Pix 介绍&#xff1a;使用条件 GAN 进行图像到图像的转换Pix2Pix 原理Pix2Pix 模型结构生成器&#xff1a;Unet结构判别器&#xff1a;PatchGAN目标函数目标函数总结 Pix2Pix 项目使用 Pix2Pix 介绍&#xff1a;使用条件 GAN 进行图像到图像的转换 Pix2Pix 论文&a…