【PyTorch】 暂退法(dropout)

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现
      • 2.2.1. 主要代码
      • 2.2.2. 完整代码
      • 2.2.3. 输出结果

1. 理论介绍

  • 线性模型泛化的可靠性是有代价的,因为线性模型没有考虑到特征之间的交互作用,由此模型灵活性受限。
  • 泛化性和灵活性之间的基本权衡被描述为偏差-方差权衡
    • 线性模型有很高的偏差,因此它们只能表示一小类函数,但其方差很低,因此它们在不同的随机数据样本上可以得出相似的结果。
    • 神经网络不局限于单独查看每个特征,而是学习特征之间的交互,但即使我们有比特征多得多的样本,深度神经网络也有可能过拟合。
  • 经典泛化理论认为,为了缩小训练和测试性能之间的差距,应该以简单的模型为目标。 简单性以较小维度的形式展现,简单性的另一个角度是平滑性,即函数不应该对其输入的微小变化敏感
  • 暂退法在前向传播过程中,计算每一内部层的同时注入噪声。 因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性。从表面上看是在训练过程中丢弃(drop out)一些神经元。 在整个训练过程的每一次迭代中,标准暂退法包括在计算下一层之前将当前层中的一些节点置零。
  • 神经网络过拟合与每一层都依赖于前一层激活值相关,这种情况称为共适应性,而暂退法会破坏共适应性。
  • 可以以一种无偏向的方式注入噪声,在固定住其他层时,每一层的期望值等于没有噪音时的值。
  • 标准暂退正则化通过按未丢弃的节点的分数进行规范化来消除每一层的偏差,换言之,每个中间活性值 h h h以暂退概率 p p p由随机变量 h ′ h' h替换,即:
    h ′ = { 0  概率为  p h 1 − p  其他情况 \begin{aligned} h' = \begin{cases} 0 & \text{ 概率为 } p \\ \frac{h}{1-p} & \text{ 其他情况} \end{cases} \end{aligned} h={01ph 概率为 p 其他情况
  • 通常,我们在测试时不用暂退法。 给定一个训练好的模型和一个新的样本,我们不会丢弃任何节点,因此不需要标准化。 然而也有一些例外:一些研究人员在测试时使用暂退法, 用于估计神经网络预测的不确定性: 如果通过许多不同的暂退法遮盖后得到的预测结果都是一致的,那么我们可以说网络发挥更稳定。
  • 我们可以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率,常见的技巧是在靠近输入层的地方设置较低的暂退概率。

2. 实例解析

2.1. 实例描述

使用具有两个隐藏层的多层感知机和暂退法,拟合Fashion-MNIST数据集。

2.2. 代码实现

2.2.1. 主要代码

net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),	# 暂退概率为0.2
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(0.5),	# 暂退概率为0.5
        nn.Linear(256, 10)
).to(device)

2.2.2. 完整代码

import os
from tensorboardX import SummaryWriter
from rich.progress import track
from torchvision.transforms import Compose, ToTensor
from torchvision.datasets import FashionMNIST
import torch
from torch.utils.data import DataLoader
from torch import nn, optim

def load_dataset():
    """加载数据集"""
    root = "./dataset"
    transform = Compose([ToTensor()])
    mnist_train = FashionMNIST(root, True, transform, download=True)
    mnist_test = FashionMNIST(root, False, transform, download=True)

    dataloader_train = DataLoader(mnist_train, batch_size, shuffle=True, 
        num_workers=num_workers,
    )
    dataloader_test = DataLoader(mnist_test, batch_size, shuffle=False,
        num_workers=num_workers,
    )
    return dataloader_train, dataloader_test


if __name__ == '__main__':
    # 全局参数设置
    num_epochs = 10
    batch_size = 256
    num_workers = 3
    device = torch.device('cuda:0')
    lr = 0.5

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir=log_dir())

    # 数据集配置
    dataloader_train, dataloader_test = load_dataset()

    # 定义模型
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(256, 10)
    ).to(device)
    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, std=0.01)
    net.apply(init_weights)
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.SGD(net.parameters(), lr=lr)

    # 训练循环
    for epoch in track(range(num_epochs), description='dropout'):
        for X, y in dataloader_train:
            X, y = X.to(device), y.to(device)
            loss = criterion(net(X), y)
            optimizer.zero_grad()
            loss.mean().backward()
            optimizer.step()

        with torch.no_grad():
            train_loss, train_acc, num_samples = 0.0, 0.0, 0
            for X, y in dataloader_train:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                loss = criterion(y_hat, y)
                train_loss += loss.sum()
                train_acc += (y_hat.argmax(dim=1) == y).sum()
                num_samples += y.numel()
            train_loss /= num_samples
            train_acc /= num_samples

            test_acc, num_samples = 0.0, 0
            for X, y in dataloader_test:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                test_acc += (y_hat.argmax(dim=1) == y).sum()
                num_samples += y.numel()
            test_acc /= num_samples

            writer.add_scalars('metrics', {
                'train_loss': train_loss,
                'train_acc': train_acc,
                'test_acc': test_acc
            }, epoch)

    writer.close()

2.2.3. 输出结果

dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/228141.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32-EXTI外部中断

一、中断系统 中断:在主程序运行过程中,出现了特定的中断触发条件(中断源),使得CPU暂停当前正在运行的程序,转而去处理中断程序,处理完成后又返回原来被暂停的位置继续运行 中断优先级&#xff…

2021年第十届数学建模国际赛小美赛B题疾病传播的风险解题全过程文档及程序

2021年第十届数学建模国际赛小美赛 B题 疾病传播的风险 原题再现: 空气传播疾病可以通过咳嗽或打喷嚏、喷洒液体或灰尘传播。另一方面,一些常见的传染病只能通过飞沫传播。请建立一个模型,以评估密闭空间内空气传播和液滴传播疾病的可能性。…

查看电脑cuda版本

1.找到NVODIA控制面板 输入NVIDIA搜索即可 出现NVIDIA控制面板 点击系统信息 2.WINR 输入nvidia-smi 检查了一下,电脑没用过GPU,连驱动都没有 所以,装驱动…… 选版本,下载 下载后双击打开安装 重新输入nvidia-smi 显示如下…

喜讯:加速度商城系统全系列产品品牌全新升级为Shopfa

2月1日讯:经过1年多的品牌文化塑造,深圳市加速度软件开发有限公司经过研究决定,将旗下的多商户商城系列、小程序商城系列、B2B商城系列、供应商集采系列、电子元器件商城系列、跨境独立站商城系列、MRO工业品商城系列、外卖商城系列、智慧零售…

C++笔记:动态内存管理

文章目录 语言层面的内存划分C语言动态内存管理的缺陷new 和 delete 的使用了解语法new 和 delete 操作内置类型new 和 delete 操作自定义类型 new 和 delete 的细节探究new 和 delete 的底层探究operator new 和 operator new[]operator delete 和 operator delete[] 显式调用…

体验官分享 | 用户眼中的OK3588-C开发板究竟有多优秀?

编者荐语:飞凌嵌入式今年共发起了5期【产品体验官】活动,让更多热爱嵌入式的朋友免费体验到了自己感兴趣的产品,飞凌嵌入式也收获了很多宝贵的建议。活动期间体验官们创作了许多优质的体验报告,今天小编就与大家分享一篇来自体验官…

Django讲课笔记01:初探Django框架

文章目录 一、学习目标二、课程导入(一)课程简介(二)课程目标(三)适用人群(四)教学方式(五)评估方式(六)参考教材 三、新课讲授&#…

SAP UI5 walkthrough step2 Bootstrap

我的理解&#xff0c;这就是一个引导指令 1.我们右键打开命令行--执行 ui5 use OpenUI5 2.执行命令&#xff1a;ui5 add sap.ui.core sap.m themelib_sap_horizon 执行完之后&#xff0c;会更新 yaml 文件 3.修改index.html <!DOCTYPE html> <html> <head&…

Java安全之Commons Collections6分析

CC6分析 import org.apache.commons.collections.*; import org.apache.commons.collections.functors.ChainedTransformer; import org.apache.commons.collections.functors.ConstantTransformer; import org.apache.commons.collections.functors.InvokerTransformer; impo…

[⑧ADRV902x]: Digital Pre-Distortion (DPD)学习笔记

前言 DPD 数字预失真技术&#xff0c;是一种用于抑制功率放大器非线性失真的方法。 它通过在信号输入功率放大器&#xff08;PA&#xff09;之前插入一个预失真模块&#xff0c;对输入信号进行适当的调制&#xff0c;以抵消功率放大器引起的非线性失真&#xff0c;使功率放大器…

线上盲盒扭蛋机,开启潮玩新玩法

盲盒近几年非常火爆&#xff0c;因其不确定性、随机性吸引着盲盒爱好者&#xff0c;引起了潮玩文化风潮。扭蛋机是盲盒的一种新抽取模式&#xff0c;线上扭蛋机小程序在具有盲盒的优势外&#xff0c;还具有较大吸引力&#xff0c;用户参与率较高&#xff0c;这也使得扭蛋机成为…

C++_命名空间(namespace)

目录 1、namespace的重要性 2、 namespace的定义及作用 2.1 作用域限定符 3、命名空间域与全局域的关系 4、命名空间的嵌套 5、展开命名空间的方法 5.1 特定展开 5.1 部分展开 5.2 全部展开 结语&#xff1a; 前言&#xff1a; C作为c语言的“升级版”&#xff0c;其在…

西南科技大学C++程序设计实验十一(泛型程序设计与C++标准模板库)

一、实验目的 1. 掌握泛型程序设计概念; 2. 掌握vector、deque、list容器使用方法; 3.了解set、map容器使用方法。 二、实验任务 1.分析完善以下程序,理解vector容器使用方法: #include <iostream> __#include <vector>_______ //补充vector模板头文件 …

(C语言实现)高精度除法 (洛谷 P2005 A/B Problem II)

前言 本期我们分享用C语言实现高精度除法&#xff0c;可通过该题测试点我点我&#xff0c;洛谷 p2005。 那么话不多说我们开始吧。 讲解 大家还记不记得小学的时候我们是怎么做除法的&#xff1f;我们以1115为例。 我们的高精度除法也将采用这个思路进行&#xff0c;分别用两…

selenium 解决 id定位、class定位中,属性值带空格的解决办法

一、前置说明 selenium遇到下面这种元素&#xff1a; <th id"demo id" class"value1 value2 value3 ">1、虽然id一般不会有空格&#xff0c;但是前端错误的这种写法(如下图)&#xff0c;会造成使用id定位不到元素&#xff0c;如&#xff1a; find…

Windows本地如何添加域名映射?(修改hosts文件)

1. DNS(域名系统) Domain Name System(域名系统)&#xff1a;为了加快定位IP地址的速度, 将域名映射进行层层缓存的系统. 目的&#xff1a;互联网通过IP&#xff08;10.223.146.45&#xff09;定位浏览器建立连接&#xff0c;但是我们不易区别IP&#xff0c;为了方便用户辨识I…

SAP 后继物料简介

后继物料(Discontinued Part)是SAP系统提供的一项用于物料继承与物料永久性替换的功能。在企业的日常生产业务中,经常会出于技术原因或成本原因进行大批量的物料替换或物料升级。比如说,企业可以用一种可靠性更高的组件替换先前使用的组件,或者出于节省成本的目的,使用一…

Google Gemini Pro 测试

谷歌新发布了Gemini模型&#xff0c;说是吊打GPT4&#xff0c;上手使用了下&#xff0c;感觉一般&#xff0c;没有感觉到预期的吊打。下面是与Bard的对话。

条码生成器与Zint使用

文章目录 目的条形码zint支持条形码种类下载编译qt pro配置code保存条形码目的 1: 了解条形码数据理论知识 2: 了解zint第三方库相关, 如何编译引用到项目中 条形码 条形码(Barcode)一维码 和二维码(QR code)都是用于存储信息的图形化表示方式,通常应用于商品标识、库…

led护眼灯真的能护眼吗?热门口碑好的护眼台灯推荐

2023年我国青少年近视率为53.6%&#xff0c;其中6岁以下儿童近视率为14.5%&#xff0c;小学生近视率为36%&#xff0c;初中生近视率为71.6%&#xff0c;高中生近视率为81%&#xff0c;大学生近视率为90%。 也就是孩子上了大学&#xff0c;十个里面有九个是近视&#xff0c;不戴…