深度学习中的Dropout正则化:原理、代码实现与实际应用——pytorch框架下如何使用dropout正则化

目录

引言

一、导入包

二、dropout网络定义

三、创建模型,定义损失函数和优化器

四、加载数据

五、训练train

六、测试


引言

dropout正则化的原理相对简单但非常有效。它在训练神经网络时,以一定的概率(通常是在0.2到0.5之间)随机地将某些神经元的输出设置为零,即“关闭”这些神经元。这些“关闭”的神经元在整个训练过程中都不参与前向传播和反向传播。

这一过程有点类似于在每次训练迭代中从网络中删除一些神经元,然后在下一次迭代中再将它们添加回去。这种随机的“删除”和“添加”过程迫使网络不依赖于特定的神经元,从而提高了模型的泛化能力。

具体来说,dropout正则化有以下几个效果:

  1. 减少过拟合: 通过随机地关闭一些神经元,dropout可以减少神经网络对训练数据的过度拟合,使其更好地适应未见过的数据。

  2. 增加网络的鲁棒性: 由于每个神经元都有可能在任何时候被关闭,网络被迫学习对于任何输入都要保持稳健,而不是过于依赖某些特定的神经元。

  3. 防止协同适应: 在训练过程中,dropout可以防止神经元之间形成过于强烈的依赖关系,防止它们过分协同适应训练数据。

这一方法的关键在于,通过在训练期间随机地关闭一些神经元,模型不再依赖于任何一个特定的神经元,从而减少了过拟合的风险。在测试阶段,所有的神经元都被保留,但其输出值要按照训练时的概率进行缩放,以保持平衡。

一、导入包

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

二、dropout网络定义

# 定义一个带有 dropout 正则化的简单神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 128)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)  # 添加50%的dropout
        self.fc2 = nn.Linear(128, 64)
        self.dropout2 = nn.Dropout(0.3)  # 添加30%的dropout
        self.fc3 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x

三、创建模型,定义损失函数和优化器

# 创建模型,定义损失函数和优化器
model = Net()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

四、加载数据

# 生成虚拟数据
torch.manual_seed(42)
X_train = torch.rand((1000, 10))
y_train = torch.randint(0, 2, (1000, 1)).float()
X_test = torch.rand((200, 10))
y_test = torch.randint(0, 2, (200, 1)).float()

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train)
test_dataset = TensorDataset(X_test, y_test)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

五、训练train

# 训练循环
for epoch in range(10):
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

六、测试

# 测试循环
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_dataloader:
        outputs = model(inputs)
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f"测试准确率:{accuracy * 100:.2f}%")

然后我们就可以模仿上述的案例自己修改一些其中的参数进行网络训练和调整了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/185231.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

做亚马逊多久可以赚钱?做亚马逊需要多少资金?——站斧浏览器

做亚马逊需要时间、资金和全面的市场策略。创业者需要有耐心和决心,同时也要灵活应对市场变化。那么做亚马逊多久可以赚钱,做亚马逊需要多少资金。 做亚马逊多久可以赚钱 首先,就像任何其他生意一样,做亚马逊需要时间和努力来建立起稳定的客…

C#关键字、特性基础及扩展合集(持续更新)

一、基础 Ⅰ 关键字 1、record record(记录),编译器会在后台创建一个类。支持类似于结构的值定义,但被实现为一个类,方便创建不可变类型,成员在初始化后不能再被改变 (C#9新增) …

MarkDown学习

MarkDown学习 标题 三级标题 四级标题 字体 加粗(两侧加两个星号):Hello,World! 斜体(两侧加一个星号):Hello,World! 加粗加斜体(两侧加三个星号)&#xff1a…

可视化工作流管理流程及工具

Leangoo领歌是一款永久免费的专业的敏捷开发管理工具,提供端到端敏捷研发管理解决方案,涵盖敏捷需求管理、任务协同、进展跟踪、统计度量等。 Leangoo领歌上手快、实施成本低,可帮助企业快速落地敏捷,提质增效、缩短周期、加速创新…

基于51单片机电子钟闹钟LCD1602显示proteus仿真设计

基于51单片机的LCD1602电子钟闹钟proteus仿真设计 基于51单片机的LCD1602电子钟闹钟proteus仿真设计功能介绍:仿真图:原理图:设计报告:程序:器件清单:资料清单&&下载链接: 基于51单片机…

记一次简单的PHP反序列化字符串溢出

今天朋友给的一道题&#xff0c;让我看看&#xff0c;来源不知&#xff0c;随手记一下 <?php // where is flag error_reporting(0); class NFCTF{ public $ming,$id,$payload,$nothing;function __construct($iii){$this->ming$ii…

主播产品转场(款)话术

直播转场话术要点 在直播过程中&#xff0c;转场话术是非常重要的一部分。它可以帮助主播J顺利地将一个主题或场景过渡到另一个主题或场景&#xff0c;同时吸引观众的注意力。提高直播的观赏性和互动性。以下是一些直播转场话术的要点: 一、过渡性话语 过渡性话语是连接两个…

机器学习高级实践

&#x1f482; 个人网站:【 海拥】【神级代码资源网站】【办公神器】&#x1f91f; 基于Web端打造的&#xff1a;&#x1f449;轻量化工具创作平台&#x1f485; 想寻找共同学习交流的小伙伴&#xff0c;请点击【全栈技术交流群】 前言 在当今科技飞速发展的时代&#xff0c;机…

Linux开发工具(含gdb调试教程)

文章目录 Linux开发工具&#xff08;含gdb调试教程&#xff09;1、Linux 软件包管理器 yum2、Linux开发工具2.1、Linux编辑器 -- vim的使用2.1.1、vim的基本概念2.1.2、vim的基本操作2.1.3、vim正常模式命令集2.1.4、vim末行模式命令集 2.2、vim简单配置 3、Linux编译器 -- gcc…

TSINGSEE青犀智能分析网关道路积水识别AI算法方案

在各处的街道、路口等区域&#xff0c;及时发现道路积水问题&#xff0c;可以大大减少城市管理部门压力&#xff0c;及时处理&#xff0c;减少交通事故与人员摔倒事故。通过道路积水AI算法&#xff0c;能有效提高城市管理部门效率&#xff0c;优化城市管理方式。 那么&#xff…

西米支付:简单介绍一下支付公司的分账功能体系

随着互联网的普及和电子商务的快速发展&#xff0c;支付已经成为人们日常生活的重要组成部分。支付公司作为第三方支付平台&#xff0c;为消费者和商家提供了便捷、安全的支付方式。而在支付领域中&#xff0c;分账功能是一个非常重要的功能&#xff0c;它可以帮助企业实现资金…

livox 半固体激光雷达 gazebo 仿真 | 更换仿真中mid360雷达外形

livox 半固体激光雷达 gazebo 仿真 | 更换仿真中mid360雷达外形 livox 半固体激光雷达 gazebo 仿真 | 更换仿真中mid360雷达外形livox 介绍更换仿真中mid360雷达外形 livox 半固体激光雷达 gazebo 仿真 | 更换仿真中mid360雷达外形 livox 介绍 览沃科技有限公司&#xff08;L…

如何在Simulink中使用syms?换个思路解决报错:Function ‘syms‘ not supported for code generation.

问题描述 在Simulink中的User defined function使用syms函数&#xff0c;报错simulink无法使用外部函数。 具体来说&#xff1a; 我想在Predefined function定义如下符号函数作为输入信号&#xff0c;在后续模块传入函数参数赋值&#xff0c;以实现一次定义多次使用&#xf…

Pix2Pix 使用指南:从原理到项目应用

Pix2Pix Pix2Pix 介绍&#xff1a;使用条件 GAN 进行图像到图像的转换Pix2Pix 原理Pix2Pix 模型结构生成器&#xff1a;Unet结构判别器&#xff1a;PatchGAN目标函数目标函数总结 Pix2Pix 项目使用 Pix2Pix 介绍&#xff1a;使用条件 GAN 进行图像到图像的转换 Pix2Pix 论文&a…

预制菜产业发展背景下,如何利用视频监控保障行业监管工作

一、方案背景 随着社会的快速发展和人们生活水平的提高&#xff0c;预制菜产业作为现代餐饮行业的重要组成部分&#xff0c;越来越受到消费者的欢迎。然而&#xff0c;由于相关监管工作的不健全或不到位&#xff0c;一些问题也相继浮现出来&#xff0c;如&#xff1a;食品安全…

比较2个点的3种结构在不规则平面上的占比

2 2 2 1 2 2 2 2 2 1 2 2 2 2 2 1 2 2 3 3 3 x 3 3 2 2 2 1 2 2 2 2 2 1 2 2 在平面上有一个点x&#xff0c;再增加一个点,11的操作把平面分成了3部分2a1&#xff0c;2a2&#xff0c;2a3&#xff0c;3部分的比值是 2a1 2a2 2a3 5 25 …

2023年微软开源八个人工智能项目

自2001年软件巨头微软前首席执行官史蒂夫鲍尔默对开源&#xff08;尤其是Linux&#xff09;发表尖刻言论以来&#xff0c;微软正在开源方面取得了长足的进步。继ChatGPT于去年年底发布了后&#xff0c;微软的整个2023年&#xff0c;大多数技术都是面向开发人员和研究人员公开发…

NX二次开发UF_CSYS_set_origin 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CSYS_set_origin Defined in: uf_csys.h int UF_CSYS_set_origin(tag_t csys_tag, double origin [ 3 ] ) overview 概述 Set origin of coordinate system. Note that this fu…

TFA-Net

TFA SCA means ‘Self-Context Aggregation’ 作者未提供代码

leetcode:环形链表的入环点

题目描述 题目链接:力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 题目分析 我们假设起点到环的入口点的距离是L&#xff0c;入口点到相遇点的距离是X&#xff0c;环的长度是C 那么画图我们可以得知&#xff1a; 从开始到相遇时slow走的距离是LX从…