《动手学深度学习(PyTorch版)》笔记4.6

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过。

Chapter4 Multilayer Perceptron

4.6 Dropout Regularization

4.6.1 Reexamine Overfitting

当面对更多的特征而样本不足时,线性模型往往会过拟合。相反,当给出更多样本而不是特征,通常线性模型不会过拟合。不幸的是,线性模型泛化的可靠性是有代价的,即线性模型没有考虑到特征之间的交互作用。对于每个特征,线性模型必须指定正的或负的权重,而忽略其他特征。

泛化性和灵活性之间的这种基本权衡被描述为偏差-方差权衡(bias-variance tradeoff)。线性模型有很高的偏差,然而,这些模型的方差很低,即它们在不同的随机数据样本上可以得出相似的结果。

深度神经网络位于偏差-方差谱的另一端。与线性模型不同,神经网络并不局限于单独查看每个特征,而是学习特征之间的交互。例如,神经网络可能推断“尼日利亚”和“西联汇款”一起出现在电子邮件中表示垃圾邮件,但单独出现则不表示垃圾邮件。

即使我们有比特征多得多的样本,深度神经网络也有可能过拟合。本节中,我们将探究改进深层网络的泛化性的工具。

4.6.2 Robustness of Disturbances

经典泛化理论认为,为了缩小训练和测试性能之间的差距,应该以简单的模型为目标。简单性以较小维度的形式展现,参数的范数也代表了一种有用的简单性度量。简单性的另一个角度是平滑性,即函数不应该对其输入的微小变化敏感。例如,当我们对图像进行分类时,我们预计向像素添加一些随机噪声应该是基本无影响的。

在训练过程中,我们可以在计算后续层之前向网络的每一层注入噪声。因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性。这个想法被称为暂退法(dropout)。暂退法在前向传播过程中,计算每一内部层的同时注入噪声,这已经成为训练神经网络的常用技术。这种方法之所以被称为暂退法,因为我们从表面上看是在训练过程中丢弃一些神经元。在整个训练过程的每一次迭代中,标准暂退法包括在计算下一层之前将当前层中的一些节点置零。
关键的挑战是如何注入这种噪声。一种想法是以一种无偏向(unbiased)的方式注入噪声。这样在固定住其他层时,每一层的期望值等于没有噪音时的值。我们可以在每次训练迭代中,从均值为零的分布 ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0,\sigma^2) ϵN(0,σ2)采样噪声添加到输入 x \mathbf{x} x,从而产生扰动点 x ′ = x + ϵ \mathbf{x}' = \mathbf{x} + \epsilon x=x+ϵ,预期是 E [ x ′ ] = x E[\mathbf{x}'] = \mathbf{x} E[x]=x

在标准暂退法正则化中,通过按保留(未丢弃)的节点的分数进行规范化来消除每一层的偏差,如下所示:

h ′ = { 0  概率为  p h 1 − p  其他情况 \begin{aligned} h' = \begin{cases} 0 & \text{ 概率为 } p \\ \frac{h}{1-p} & \text{ 其他情况} \end{cases} \end{aligned} h={01ph 概率为 p 其他情况

根据此模型的设计,其期望值保持不变,即 E [ h ′ ] = h E[h'] = h E[h]=h

4.6.3 Implementation

当我们将暂退法应用到隐藏层,以 p p p的概率将隐藏单元置为零时,结果可以看作一个只包含原始神经元子集的网络。比如在下图中,删除了 h 2 h_2 h2 h 5 h_5 h5,因此输出的计算不再依赖于 h 2 h_2 h2 h 5 h_5 h5,并且它们各自的梯度在执行反向传播时也会消失。这样,输出层的计算不会过度依赖于 h 1 , … , h 5 h_1, \ldots, h_5 h1,,h5的任何一个元素。

在这里插入图片描述

通常,我们在测试时不用暂退法,然而也有一些例外,比如一些研究人员在测试时使用暂退法,用于估计神经网络预测的“不确定性”:如果通过许多不同的暂退法遮盖后得到的预测结果都是一致的,那么我们可以说网络发挥更稳定。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

#以dropout的概率丢弃X中的元素
def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 在本情况中,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 在本情况中,所有元素都被保留
    if dropout == 0:
        return X
    #从均匀分布$U[0, 1]$中抽取样本,使得样本和节点一一对应,然后保留那些对应样本大于p的节点
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

#X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
#print(X)
#print(dropout_layer(X, 0.))
#print(dropout_layer(X, 0.5))
#print(dropout_layer(X, 1.))

#定义模型
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
dropout1, dropout2 = 0.2, 0.5#常见的技巧是在靠近输入层的地方设置较低的暂退概率

class Net(nn.Module):#indicates that Net is inheriting from the nn.Module class
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out


net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

#训练和测试
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
plt.show()

#简洁实现
net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

#训练和测试
trainer = torch.optim.SGD(net.parameters(), lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/354882.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MyBatis详解(3)-- 动态代理及映射器

MyBatis详解&#xff08;3&#xff09; mybatis 动态代理动态代理的规范selectOne和selectListnamespace mybatis映射器映射器的引入&#xff1a; 映射器的组成select 元素结构&#xff1a;单个参数传递多个参数传递 insert 元素结构主键回填&#xff1a;自定义主键生成规则 u …

Linux中查看端口被哪个进程占用、进程调用的配置文件、目录等

1.查看被占用的端口的进程&#xff0c;netstat/ss -antulp | grep :端口号 2.通过上面的命令就可以列出&#xff0c;这个端口被哪些应用程序所占用&#xff0c;然后找到对应的进程PID https://img-blog.csdnimg.cn/c375eb2bed754426b373907acaa7346e.png 3.根据PID查询进程。…

Kafka-服务端-GroupCoordinator

在每一个Broker上都会实例化一个GroupCoordinator对象&#xff0c;Kafka按照Consumer Group的名称将其分配给对应的GroupCoordinator进行管理&#xff1b; 每个GroupCoordinator只负责管理Consumer Group的一个子集&#xff0c;而非集群中全部的Consumer Group。 请注意与Kaf…

数据结构篇-03:堆实现优先级队列

本文着重在于讲解用 “堆实现优先级队列” 以及优先级队列的应用&#xff0c;在本文所举的例子中&#xff0c;可能使用优先级队列来解并不是最优解法&#xff0c;但是正如我所说的&#xff1a;本文着重在于讲解“堆实现优先级队列” 堆实现优先级队列 堆的主要应用有两个&…

OpenCV 2 - 矩阵的掩膜操作

1知识点 1-1 CV_Assert(myImage.depth() == CV_8U); 确保输入图像是无符号字符类型,若该函数括号内的表达式为false,则会抛出一个错误。 1-2 Mat.ptr(int i = 0); 获取像素矩阵的指针,索引 i 表示第几行,从0开始计行数。 1-3 const uchar* current = mylmage.ptr(row); 获得…

React 组件生命周期-概述、生命周期钩子函数 - 挂载时、生命周期钩子函数 - 更新时、生命周期钩子函数 - 卸载时

React 组件生命周期-概述 学习目标&#xff1a; 能够说出组件的生命周期一共几个阶段 组件的生命周期是指组件从被创建到挂在到页面中运行&#xff0c;在到组件不用时卸载组件 注意&#xff1a;只有类组件才有生命周期&#xff0c;函数组件没有生命周期(类组件需要实例化&…

uni-app 开发着突然忘记项目所在位置 教你快速通过HBuilder X定位到项目的位置

我经常会开发着 开发着 就忘记项目在哪了 我们可以用编辑器打开项目 然后右键项目目录 然后选择这个 使用命令行窗口打开所在目录(U) 这样 他就会快速用 本地文件夹 帮你打开这个目录了 还可以 右键项目 选择 使用命令行窗口打开所在目录(U) 下面就会帮你打开这个目录的终端…

腾讯云一键搭建幻兽帕鲁服务器教程

幻兽帕鲁&#xff08;Palworld&#xff09;是一款多人在线游戏&#xff0c;为了获得更好的游戏体验&#xff0c;许多玩家选择自行搭建游戏联机服务器&#xff0c;但是如何搭建游戏联机服务器成为一个难题&#xff0c;腾讯云提供了游戏联机服务器一键部署方案&#xff0c;让大家…

java8 映射方法(map,flatMap)

5.2 映射&#xff08;map&#xff0c;flatMap&#xff09; 一个非常常见的数据处理套路就是 从某些对象中选择信息。比如在SQL里&#xff0c;你可以从表中选择一列。Stream API也通过map和flatMap方法提供了类似的工具。 5.2.1 对流中每一个元素应用函数&#xff08;map&am…

DMA 和 零拷贝技术 到 网络大文件传输优化

文章目录 DMA 控制器的发展无 DMA 控制器 IO 过程DMA 控制器 传统文件传输性能有多糟糕&#xff1f;如何优化文件传输性能零拷贝技术mmap writesendfileSG-DMA&#xff08;The Scatter-Gather Direct Memory Access&#xff09; 零拷贝技术的应用 大文件传输应该用什么方式Pag…

第二百九十二回

文章目录 1. 概念介绍2. 方法与细节2.1 实现方法2.2 具体细节 3. 示例代码4. 内容总结 我们在上一章回中介绍了"如何混合选择图片和视频文件"相关的内容&#xff0c;本章回中将介绍如何混合选择多个图片和视频文件.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1…

OpenGL/C++_学习笔记(四)空间概念与摄像头

汇总页 上一篇: OpenGL/C_学习笔记&#xff08;三&#xff09; 绘制第一个图形 OpenGL/C_学习笔记&#xff08;四&#xff09;空间概念与摄像头 空间概念与摄像头前置科技树: 线性代数空间概念流程简述各空间相关概念详述 空间概念与摄像头 前置科技树: 线性代数 矩阵/向量定…

毕业设计过程学习

传统的目标检测算法主要通过人工设计与纹理、颜色和形状相关的特征来进行目标区域特征的提取。随着深度学习和人工智能技术的飞速发展&#xff0c;目标检测技术也取得了很大的成就。早期基于深度学习的目标检测算法的研究方向仍然是将目标定位任务和图像分类任务分离开来的&…

1 月 27日算法练习-贪心

文章目录 扫地机器人分糖果最小战斗力差距谈判纪念品分组 扫地机器人 思路&#xff1a; 最优机器人清理方法&#xff1a;机器人清理方法先扫左边&#xff0c;有时间再扫右边。最短时间&#xff1a;通过枚举&#xff0c;从 1 开始&#xff0c;清理面积会越大直到全部面积的清理…

深入理解C语言(3):自定义类型详解

文章主题&#xff1a;结构体类型详解&#x1f30f;所属专栏&#xff1a;深入理解C语言&#x1f4d4;作者简介&#xff1a;更新有关深入理解C语言知识的博主一枚&#xff0c;记录分享自己对C语言的深入解读。&#x1f606;个人主页&#xff1a;[₽]的个人主页&#x1f3c4;&…

事务:分布式事务与本地事务的区别

分布式事务章节 分布式事务&#xff1a;2PC与3PC的区别-CSDN博客 分布式事务&#xff1a;X/Open DTP分布式事务处理模型与分布式事务处理XA规范-CSDN博客 事务简介 事务(Transaction)是操作数据库中某个数据项的一个程序执行单元(unit)。事务是由一组操作构成的可靠的独立的…

[SWPUCTF 2018]SimplePHP1

打开环境 有查看文件跟上传文件&#xff0c;查看文件里面显示没有文件url貌似可以文件读取 上传文件里面可以上传文件。 先看一下可不可以文件读取 /etc/passwd不能读取&#xff0c;源码提示flag在f1ag.php 看看能不能读取当前的文件&#xff0c; 先把代码摘下来 file.php …

LPC系列一个定时器不同频率

1.背景 最近研究的LPC804里只有一个ctimer&#xff0c;很多时候用的捉襟见肘的&#xff0c;官方给了一份双匹配的参考例程&#xff0c;不过实际用处不大。不过我花了一晚上的时间&#xff0c;终于研究出来将一个定时器拆成四个定时器用的办法了。这个方法适用于用回调函数的LP…

Fastbee物联网项目新手快速入门

一&#xff0c;前提条件 后端环境准备如下&#xff1a; 正式环境推荐硬件资源最低要求4c8G&#xff0c;硬盘40G。JDK 1.8.0_2xx (需要小版本号大于200) 。Maven3.6.3。&#xff08;IDEA启动时使用IDEA默认自带的版本即可&#xff09;。 启动fastbee之前&#xff0c;请先确定…

go语言(十七)----json

1、结构体转json package mainimport ("encoding/json""fmt" )type Movie struct{Title string json:"title"Year int json:"year"Price int json:"rmb"Actors []string json:"actors" }func main() {movie : Mo…