【深度学习实战(33)】训练之model.train()和model.eval()

一、model.train(),model.eval()作用?

model.train() 和 model.eval() 是 PyTorch 中的两个方法,用于设置模型的训练模式和评估模式。

model.train() 方法将模型设置为训练模式。在训练模式下,模型会启用 dropout 和 batch normalization 等正则化方法,并且可以计算梯度以进行参数更新,同时还可以追踪梯度计算的图。训练时,均值、方差分别是该批次内数据相应维度的均值与方差

model.eval() 方法将模型设置为评估模式。在评估模式下,模型会禁用 dropout 和 batch normalization 等正则化方法,这样可以保证每次评估的结果是确定的。评估模式下的模型通常用于模型的测试、验证或推理阶段。推理时,均值、方差是基于所有批次的期望计算所得

区分训练模式和评估模式的目的在于保证模型在不同阶段的行为一致性。例如,在训练模式下,模型需要计算并追踪梯度以进行反向传播和参数更新;而在评估模式下,模型不需要计算梯度,只需要给出确定的预测结果。

二、model.train(),model.eval()对dropout产生的影响

使用model.train():有神经元被置零,且比例符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nn

model = nn.Dropout(0.5)
model.train()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述
使用model.eval():没有神经元置零,nn.Dropout(0.5)被关闭

import torch
import torch.nn as nn

model = nn.Dropout(0.5)
#model.train()
model.eval()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

不使用model.train()和model.eval():有神经元被置零,但是比例非常随机,不符合nn.Dropout(0.5)中的0.5设定
import torch
import torch.nn as nn

model = nn.Dropout(0.5)
#model.train()
#model.eval()
input = torch.rand([3, 4])

print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

三、model.train(),model.eval()对batch normalization产生的影响

使用model.eval():bn中的均值,方差,不发生改变

# 1.导入所需的库:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# 2.定义数据集的转换方法。MNIST数据集是由28x28像素的手写数字组成的图像,将其转换为torch张量并进行标准化处理:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# 3.下载MNIST数据集并进行转换:
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)

# 4.创建数据加载器:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=0)

# 5.现在你可以使用trainloader和testloader来获取训练集和测试集的批次数据了。例如,可以使用迭代器遍历数据集中的批次:
#dataiter = iter(trainloader)
#images, labels = dataiter.next()

# 上述代码将返回一个批次的图像和对应的标签。可以使用images和labels来进行模型的训练和评估。
# 这就是使用torch库自带的MNIST数据集的基本流程。根据需要,你还可以添加其他的数据处理和增强步骤。


# 定义模型
class Model(nn.Module):
    def __init__(self, hidden_num=32, out_num=10):
        super().__init__()
        self.fc1 = nn.Linear(28*28, hidden_num)
        self.bn  = nn.BatchNorm1d(hidden_num)
        self.fc2 = nn.Linear(hidden_num, out_num)
        self.softmax = nn.Softmax()
    def forward(self, inputs, **kwargs):
        x = inputs.flatten(1)
        x = self.fc1(x)
        
        print("========= bn之前存的数据: =========")
        print(self.bn.running_mean, self.bn.running_var)
        print()
        

        print("========= 当前 Batch 的数据: =========")
        x_mean = torch.mean(x,0)
        x_variance = torch.mean((x - x_mean)*(x - x_mean),0)
        print(x_mean, x_variance)
        print()
        

        print("========= torch官方计算之后的bn新数据: =========")
        x = self.bn(x)
        print(self.bn.running_mean, self.bn.running_var)
        print()
        
        # x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
    
torch.manual_seed(1)
model = Model()
#model.train()
model.eval()
for img, label in trainloader:
    label = nn.functional.one_hot(label.flatten(), 10)
    out = model(img)
    break

在这里插入图片描述
使用model.train():bn中的均值,方差,通过滑动平均地方式发生改变,

torch.manual_seed(1)
model = Model()
model.train()
#model.eval()
for img, label in trainloader:
    label = nn.functional.one_hot(label.flatten(), 10)
    out = model(img)
    break

在这里插入图片描述
不使用model.train()和model.eval():默认bn中的均值,方差,通过滑动平均地方式发生改变,
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/601926.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

|Python新手小白中级教程|第二十三章:列表拓展之——元组

文章目录 前言一、列表复习1.索引、切片2.列表操作字符3.数据结构实践——字典 二、探索元组1.使用索引、切片2.使用__add__((添加元素,添加元素))3.输出元组4.使用转化法删除元组指定元素5.for循环遍历元组 三、元组VS列表1.区别2.元组(tuple&#xff0…

零门槛副业兼职!10种长期赚钱好方法!

想要实现财务自由,不能仅停留在梦想层面,更需要付诸实践。 以下是我从网络上精心整理的十大可靠的兼职副业建议,旨在助你一臂之力。 这些项目已根据推荐程度、难度水平、目标人群以及预期收入进行了细致分类。 我要强调的是,任…

Cosmo Bunny Girl

可爱的宇宙兔女郎的3D模型。用额外的骨骼装配到Humanoid上,Apple混合了形状。完全模块化,包括不带衣服的身体。 技术细节 内置,包括URP和HDRP PDF。还包括关于如何启用URP和HDRP的说明。 LOD 0:面:40076,tris 76694,verts 44783 装配了Humanoid。添加到Humanoid中的其他…

带EXCEL附件邮件发送相关代码

1.查看生成的邮件 2.1 非面向对象的方式(demo直接copy即可) ​ REPORT Z12. DATA: IT_DOCUMENT_DATA TYPE SODOCCHGI1,IT_CONTENT_TEXT TYPE STANDARD TABLE OF SOLISTI1 WITH HEADER LINE,IT_PACKING_LIST TYPE TABLE OF SOPCKLSTI1 WITH HEADER LIN…

AI编码工具-通义灵码初识

AI编码工具-通义灵码初识 通义灵码支持环境及语言代码安全性 通义灵码安装通义灵码登录 关于通义灵码的初识,还是得从2023云栖大会来说起。2023云栖大会带来了跨越式升级的千亿级参数规模大模型——通义千问2.0,随之而来的便有通义灵码,那么什…

SpringBoot项目配置HTTPS接口的安全访问

参考:https://blog.csdn.net/weixin_45355769/article/details/131727935 安装好openssl后, 创建 D:\certificate CA文件夹下包含: index.txt OpenSSL在创建自签证书时会向该文件里写下索引database.txt OpenSSL会模拟数据库将一些敏感信息…

嵌入式移植7Z解压缩(纯C)

本文分享一个纯C语言编写的7Z解压缩代码库,本代码库的主要目的是在嵌入式环境下使用7z解压缩文件,可以将升级包通过7z进行压缩,然后发送给设备,减小和设备传输过程中的文件大小,进而达到传输大文件的目的。 下载链接 …

2024年最佳音频处理软件盘点!助你事半功倍

在数字媒体时代,音频处理软件已经成为音乐制作、音频编辑和后期处理不可或缺的工具。这些软件具备强大的功能,能帮助用户轻松实现声音的剪辑、混音、特效处理以及音频格式转换等操作。本文将为你介绍音频处理软件的基本概念、功能特点以及常用软件&#…

【@ohos.events.emitter (Emitter)】

ohos.events.emitter (Emitter) 本模块提供了在同一进程不同线程间,或同一进程同一线程内,发送和处理事件的能力,包括持续订阅事件、单次订阅事件、取消订阅事件,以及发送事件到事件队列的能力。 说明: 本模块首批接…

实时Linux对EtherCAT工业自动化协议的支持

在自动化技术和工业控制领域,实时通信网络的重要性不断增长。EtherCAT(Ethernet for Control Automation Technology)作为一种高效的工业以太网通信协议,因其出色的性能和灵活性而广受欢迎。而实时Linux作为影响最为广泛的开源实时…

【Web前端】盒子模型_元素分类_表格

1、盒子模型 1.1简介 CSS盒子模型是在网页设计中经常用到的CSS技术所使用的一种思维模型。包括内容(content)、内边距(padding)、边框(border)、外边距(margin) 1.2边框(border) 1.2.1简介 边框是环绕内容区和填充的边界。边框的属性有border-style、…

Pytorch 实现情感分析

情感分析 情感分析是 NLP 一种应用场景,模型判断输入语句是积极的还是消极的,实际应用适用于评论、客服等多场景。情感分析通过 transformer 架构中的 encoder 层再加上情感分类层进行实现。 安装依赖 需要安装 Poytorch NLP 相关依赖 pip install t…

免费SSL证书?轻松申请攻略来了!

在当今的互联网时代,网络安全已经成为一个不容忽视的重要课题。随着在线交流和交易活动的增加,保护网站和用户信息的重要性日益突显。SSL证书,即安全套接字层证书,它为互联网通信提供了加密服务,确保数据的安全性和完整…

光伏远动通讯屏的组成

光伏远动通讯屏的组成 远动通讯屏主要用于电力系统数据采集与转发,远动通讯屏能够采集站内的各种数据,如模拟量、开关量和数字量等,并通过远动通讯规约将必要的数据上传至集控站或调度系统。这包括但不限于主变和输电线路的功率、电流、电压等…

怎么设置一天多个时间点的闹钟提醒?

在日常生活中,我们经常需要在一天的不同时间点完成特定的任务,如定时喝水、定时查看后台数据、定时吃药等。这时候,如果能有一款软件,可以在一条日程里轻松设置多个时间点的闹钟提醒,那将大大提高我们的工作效率和生活…

如何理解GTX接收通道相关模块?(高速收发器三)

前文讲解了GTX的时钟及发送通道相关内容,本文讲解GTX接收通道的一些功能及其IP配置,接收往往比发送设计更难,与调制解调,加密解密其实相差不大,后者难度都比前者高出很多。GTX的接收通道的功能相比发送通道更加重要&am…

西奥CHT-01软胶囊硬度测试仪:重塑行业标杆,引领硬度测试新纪元

西奥CHT-01软胶囊硬度测试仪:重塑行业标杆,引领硬度测试新纪元 在当今医药领域,软胶囊作为一种广泛应用的药品剂型,其品质的稳定性和安全性直接关系到患者的健康。而在确保软胶囊品质的各项指标中,硬度测试尤为关键。…

AIGC实战——多模态模型DALL.E 2

AIGC实战——多模态模型DALL.E 2 0. 前言1. 模型架构2. 文本编码器3. CLIP4. 先验模型4.1 自回归先验模型4.2 扩散先验模型 5. 解码器5.1 GLIDE5.2 上采样器 6. DALL.E 2 应用6.1 图像变体6.2 先验模型的重要性6.3 DALL.E 2 限制 小结系列链接 0. 前言 DALL.E 2 是 OpenAI 设计…

领域驱动设计架构演进

领域驱动设计由于其强调对领域的深入理解和关注业务价值,其架构演进依赖于领域的变化和特定领域中的技术实践。 初始阶段 一个单体架构,所有的功能都集成在一个应用程序中,领域模型可能还不完全清晰,甚至并未形成。这个阶段主要是为了验证产品的可行性,快速迭代并尽快推…

有没有国内个人可用的GPT平替?推荐5个AI工具

随着AI技术的快速发展,AI写作正成为创作的新风口。但是面对GPT-4这样的国际巨头,国内很多小伙伴往往望而却步,究其原因,就是它的使用门槛高,还有成本的考量。 不过,随着GPT技术的火热,国内也涌…