动手学深度学习(Pytorch版)代码实践 -深度学习基础-11暂退法Dropout

11暂退法Dropout

#Dropout 是一种正则化技术,主要用于防止过拟合,
#通过在训练过程中随机丢弃神经元来提高模型的泛化能力。
import torch
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp

def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1

    #该情况下,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    
    #该情况下所有元素都被保留
    if dropout == 0:
        return X
 
    mask = (torch.rand(X.shape) > dropout).float()
    """
    生成一个与 X 形状相同的掩码张量 mask。torch.rand(X.shape) 生成一个元素值在 [0, 1) 范围内的均匀分布的随机张量。
    mask 中的每个元素与 dropout 进行比较,若大于 dropout 则为 1(保留),否则为 0(丢弃)。最后将布尔值转换为浮点数。
    """
    #将 mask 和 X 元素逐位相乘,以应用掩码效果,即丢弃部分神经元。
    #为了保持输出的期望值不变,结果除以 (1.0 - dropout) 进行缩放补偿。
    return mask * X / (1.0 - dropout)


#测试dropout_layer函数
X= torch.arange(16, dtype = torch.float32).reshape((2, 8))
print(X)
print(dropout_layer(X, 0.))
print(dropout_layer(X, 0.5))
print(dropout_layer(X, 1.))
"""
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11., 12., 13., 14., 15.]])
tensor([[ 0.,  0.,  4.,  6.,  0.,  0., 12.,  0.],
        [16.,  0., 20.,  0.,  0., 26., 28.,  0.]])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

"""
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256

dropout1, dropout2 = 0.2, 0.5
#dropout1:第一个隐藏层之后的 dropout 比例为 20%。
#dropout2:第二个隐藏层之后的 dropout 比例为 50%。

class Net(nn.Module):
    def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
                 is_training = True):
        super(Net, self).__init__()
        self.num_inputs = num_inputs
        self.training = is_training
        self.lin1 = nn.Linear(num_inputs, num_hiddens1)
        self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
        self.lin3 = nn.Linear(num_hiddens2, num_outputs)
        self.relu = nn.ReLU()

    def forward(self, X):
        #将输入张量X重塑为 (batch_size, num_inputs) 的形状。
        H1 = self.relu(self.lin1(X.reshape((-1, self.num_inputs))))
        # 只有在训练模型时才使用dropout
        if self.training == True:
            # 在第一个全连接层之后添加一个dropout层
            H1 = dropout_layer(H1, dropout1)
        H2 = self.relu(self.lin2(H1))
        if self.training == True:
            # 在第二个全连接层之后添加一个dropout层
            H2 = dropout_layer(H2, dropout2)
        out = self.lin3(H2)
        return out

net = Net(num_inputs, num_outputs, num_hiddens1, num_hiddens2)

num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss(reduction='none')
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)
# lp.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
# d2l.plt.show() 

#Dropout 简洁实现
net = nn.Sequential(
    nn.Flatten(),#它会将多维的输入张量展平成一维。
    nn.Linear(784,256),
    nn.ReLU(),

    #在第一个全连接层之后添加一个dropout层
    nn.Dropout(dropout1),
    nn.Linear(256,256),
    nn.ReLU(),

    #在第二个全连接层后添加一个dropout层
    nn.Dropout(dropout2),
    nn.Linear(256,10)
)

#函数接受一个参数 m,通常是一个神经网络模块(例如,线性层,卷积层等)
def init_weights(m):
#这行代码检查传入的模块 m 是否是 nn.Linear 类型,即线性层(全连接层)
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
#m.weight 是线性层的权重矩阵。
#std=0.01 指定了初始化权重的标准差为 0.01,表示权重将从均值为0,标准差为0.01的正态分布中随机采样。

#model.apply(init_weights) 会遍历模型的所有模块,并对每个模块调用 init_weights 函数。
#如果模块是 nn.Linear 类型,则初始化它的权重。
net.apply(init_weights)

trainer = torch.optim.SGD(net.parameters(), lr = lr)
lp.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show() 

运行结果:

epoch: 1,train_loss: 1.1632388224283854,train_acc: 0.55005,test_acc: 0.7137
<Figure size 350x250 with 1 Axes>
epoch: 2,train_loss: 0.5765969015757243,train_acc: 0.7862833333333333,test_acc: 0.7971
<Figure size 350x250 with 1 Axes>
epoch: 3,train_loss: 0.5013401063283285,train_acc: 0.8177166666666666,test_acc: 0.6976
<Figure size 350x250 with 1 Axes>
epoch: 4,train_loss: 0.46441060066223144,train_acc: 0.8299666666666666,test_acc: 0.837
<Figure size 350x250 with 1 Axes>
epoch: 5,train_loss: 0.4177045190811157,train_acc: 0.8482,test_acc: 0.8348
<Figure size 350x250 with 1 Axes>
epoch: 6,train_loss: 0.4039476199467977,train_acc: 0.8522,test_acc: 0.8376
<Figure size 350x250 with 1 Axes>
epoch: 7,train_loss: 0.38559712861378986,train_acc: 0.8593333333333333,test_acc: 0.8499
<Figure size 350x250 with 1 Axes>
epoch: 8,train_loss: 0.37514646828969317,train_acc: 0.86315,test_acc: 0.8587
<Figure size 350x250 with 1 Axes>
epoch: 9,train_loss: 0.36000535113016763,train_acc: 0.8681166666666666,test_acc: 0.853
<Figure size 350x250 with 1 Axes>
epoch: 10,train_loss: 0.3473748308181763,train_acc: 0.8719333333333333,test_acc: 0.85

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/721437.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

安全宣传咨询日活动向媒体投稿记住这个投稿好方法

在信息爆炸的时代,作为单位的信息宣传员,我肩负着将每一次重要活动,特别是像“安全宣传咨询日”这样的公益活动,有效传达给公众的重任。这份工作看似简单,实则充满了挑战,尤其是在我初涉此领域时,那段曲折而又难忘的投稿经历,至今记忆犹新。 初探投稿之海,遭遇重重困难 起初,我…

这些数据可被Modbus采集,你还不知道???

为什么要用Modbus采集模块 Modbus采集模块之所以被广泛使用&#xff0c;是因为它提供了标准化的通信协议&#xff0c;确保了不同设备间的兼容性。它支持多种通信方式&#xff0c;易于实现&#xff0c;并且能够适应不同的网络环境。Modbus模块能够收集和传输各种工业数据&#x…

【产品经理】订单处理6-审单方案

电商系统中订单管理员会对特殊类型的订单进行审核&#xff0c;普通订单则自动审核&#xff0c;本节讲述自动审单方案、手动审单以及加急审单。 一、自动审单 自动审单方案可按照方案形式制定&#xff0c;可一次性制定多套审单方案。 1. 审单通过条件有 执行店铺&#xff…

大模型的分类:探索多样化的人工智能模型

随着人工智能技术的飞速发展&#xff0c;大型预训练模型&#xff08;以下简称“大模型”&#xff09;已经在自然语言处理、计算机视觉、语音识别等多个领域取得了显著的成果。这些模型通过在海量数据上进行预训练&#xff0c;能够捕捉到丰富的特征信息&#xff0c;为各种下游任…

Linux操作系统学习:day03

内容来自&#xff1a;Linux介绍 视频推荐&#xff1a;[Linux基础入门教程-linux命令-vim-gcc/g -动态库/静态库 -makefile-gdb调试]( 目录 day0317、创建删除目录创建目录删除目录 18、文件的拷贝19、mv 命令20、查看文件内容的相关命令21、给文件创建软连接或硬链接 day03 …

MFC绘制哆啦A梦

OnPaint绘制代码 CPaintDC dc(this); // 用于绘画的设备上下文CRect rc;GetWindowRect(rc);int cxClient rc.Width();int cyClient rc.Height();// 辅助线HPEN hPen CreatePen(PS_DOT, 1, RGB(192, 192, 192));HPEN hOldPen (HPEN)SelectObject(dc, hPen);MoveToEx(dc, cxC…

使用Vue中的<TransitionGroup/>进入动画不生效不显示问题

Vue中有两个过渡动画组件分别是&#xff1a;<TransitionGroup/> <TransitionGroup/>进入动画不生效不显示问题 &#xff0c;在渲染列表上加上v-if&#xff0c;看代码&#xff0c;让他每次渲染都重新渲染 加上v-if即可 <template> <TransitionGroup nam…

Perforce静态代码分析专家解读MISRA C++:2023®新标准:如何安全、高效地使用基于范围的for循环,防范未定义行为

MISRA C&#xff1a;2023——MISRA C 标准的下一个版本来了&#xff01;为了帮助您了解 MISRA C&#xff1a;2023相比于之前版本的变化&#xff0c;我们将继续为您带来Perforce首席技术支持工程师Frank van den Beuken博士的博客系列&#xff0c;本期为第三篇。 在前两篇系列文…

和服务器建立联系——6.10山大软院项目实训1

下面介绍我如何在自己的项目中&#xff0c;根据aigc组的接口&#xff08;如下图&#xff09;&#xff0c;在Unity中和服务器建立联系并发出接受请求的&#xff1a; 这是一个通过HTTP POST方法调用的接口&#xff0c;需要发送JSON格式的数据。在Unity中实现这样的功能&#xff0…

文字炫酷祝福 含魔法代码

效果下图&#xff1a;&#xff08;可自定义显示内容&#xff09; 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initi…

SpringBoot + Maven 项目的创建

文章目录 1、Maven2、SpringBoot3、二者之间的联系4、项目的创建 在创建项目之前&#xff0c;肯定要知道他们之间的区别 1、Maven maven是一个跨平台的项目管理工具。它是Apache的一个开源项目&#xff0c;主要服务于基于Java平台的项目构建、依赖管理和项目信息管理。 比如说…

QT day04

一、思维导图 二、登录界面优化 代码&#xff1a; 界面&#xff1a; *{background-color: rgb(255, 255, 255); }QFrame#frame{border-image: url(:/Logo/shanChuan.jpg);border-radius:15px; }#frame_2{background-color: rgba(110, 110, 110, 120);border-radius:15px; }Q…

线代的学习(矩阵)

1.矩阵的乘法 矩阵实现满足&#xff1a;内标相等 矩阵相乘之后的结果&#xff1a;前行后列 需要注意&#xff1a;1.矩阵的乘法不具有交换律&#xff1a;AB!BA 2.矩阵的乘法满足分配律&#xff1a;A(BC) AB AC 抽象逆矩阵求逆矩阵 方法1.凑定义法、 方法2.长除法 数字型矩阵…

一文弄懂 Python os.walk(),轻松搞定文件处理和目录遍历

&#x1f349; CSDN 叶庭云&#xff1a;https://yetingyun.blog.csdn.net/ Python os 模块的 walk() 方法以自顶向下或自底向上的方式遍历指定的目录树&#xff0c;从而显示目录树中的文件名。对于目录树中的每个目录&#xff0c;os.walk() 方法都会产生一个包含目录路径、当前…

vue3第四十节(pinia的用法注意事项解构store)

pinia 主要包括以下五部分&#xff0c;经常用到的是 store、state、getters、actions 以下使用说明&#xff0c;注意事项&#xff0c;仅限于 vue3 setup 语法糖中使用&#xff0c;若使用选项式 API 请直接查看官方文档&#xff1a; 一、前言&#xff1a; pinia 是为了探索 vu…

04-对原生app应用中的元素进行定位

本文介绍对于安卓原生app应用中的元素如何进行定位。 一、uiautomatorviewer uiautomatorviewer是Android-SDK自带的一个元素定位工具&#xff0c;非常简单好用&#xff0c;可以使用该工具查看app应用中的元素属性&#xff0c;帮助我们在代码中进行元素定位。 1&#xff09;使…

Win11版本21H2怎么升级为23H2?升级详细步骤在此!

在Win11电脑操作中&#xff0c;用户目前使用的版本是21H2&#xff0c;现在想体验23H2版本的先进功能&#xff0c;但不知道要怎么操作才能将系统版本升级为23H2&#xff1f;接下来小编给大家介绍详细的升级方法步骤&#xff0c;助力大家轻松完成系统版本升级操作。 方法一&#…

VirtualStudio配置QT开发环境

环境 VirtualStudio2022Qt5.12.10 安装msvc工具链&#xff08;这一步不是必须的&#xff09; 打开virtual studio&#xff0c;打开Virtual Studio Installer界面选择要安装的msvc版本&#xff0c;点击安装 安装VirtualStudio扩展 在线安装 打开virtual Studio&#xff0c;…

ps2024磨皮滤镜插件Portraiture升级版下载-Portraiture2024软件最新版下载附加安装步骤

不少小伙伴在制作了照片后都会通过一些形式进行美化解决&#xff0c;今日小编就给大家详细介绍一款非常不错的专用工具&#xff0c;它是Corel PaintShop Pro 2024 手机软件&#xff0c;此软件为消费者提供了技术专业完备的视频后期制作作用&#xff0c;能够让消费者轻轻松松将为…

批量创建文件夹 就是这么简单 一招创建1000+文件夹

批量创建文件夹 就是这么简单 一招创建1000文件夹 在工作中&#xff0c;或者生活中&#xff0c;我们经常要用到批量创建文件夹&#xff0c;并且根据不同的工作需求&#xff0c;要求是不一样的&#xff0c;比如有些人需要创建上千个不一样名称的文件夹&#xff0c;如果靠手动创…