PyTorch之nn.Module与nn.functional用法区别

文章目录

  • 1. nn.Module
  • 2. nn.functional
    • 2.1 基本用法
    • 2.2 常用函数
  • 3. nn.Module 与 nn.functional
    • 3.1 主要区别
    • 3.2 具体样例:nn.ReLU() 与 F.relu()
  • 参考资料

1. nn.Module

在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。

关于 nn.Module 的更多介绍可以参考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解

这里,我们基于nn.Module创建一个简单的神经网络模型,实现代码如下:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.layer2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

2. nn.functional

nn.functional 是PyTorch中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module 不同,nn.functional 中的函数不具有可学习的参数。这些函数通常用于执行各种非线性操作、损失函数、激活函数等。

2.1 基本用法

如何在神经网络中使用nn.functional?

在PyTorch中,你可以轻松地在神经网络中使用 nn.functional 函数。通常,你只需将输入数据传递给这些函数,并将它们作为网络的一部分。

以下是一个简单的示例,演示如何在一个全连接神经网络中使用ReLU激活函数:

import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

在上述示例中,我们首先导入nn.functional 模块,然后在网络的forward 方法中使用F.relu 函数作为激活函数。

nn.functional 的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。

2.2 常用函数

(1)激活函数

激活函数是神经网络中的关键组件,它们引入非线性性,使网络能够拟合复杂的数据。以下是一些常见的激活函数:

  • ReLU(Rectified Linear Unit)
    ReLU是一种简单而有效的激活函数,它将输入值小于零的部分设为零,大于零的部分保持不变。它的数学表达式如下:
output = F.relu(input)
  • Sigmoid
    Sigmoid函数将输入值映射到0和1之间,常用于二分类问题的输出层。它的数学表达式如下:
output = F.sigmoid(input)
  • Tanh(双曲正切)
    Tanh函数将输入值映射到-1和1之间,它具有零中心化的特性,通常在循环神经网络中使用。它的数学表达式如下:
output = F.tanh(input)

(2)损失函数

损失函数用于度量模型的预测与真实标签之间的差距。PyTorch的nn.functional 模块包含了各种常用的损失函数,例如:

  • 交叉熵损失(Cross-Entropy Loss)
    交叉熵损失通常用于多分类问题,计算模型的预测分布与真实分布之间的差异。它的数学表达式如下:
loss = F.cross_entropy(input, target)
  • 均方误差损失(Mean Squared Error Loss)
    均方误差损失通常用于回归问题,度量模型的预测值与真实值之间的平方差。它的数学表达式如下:
loss = F.mse_loss(input, target)
  • L1 损失
    L1损失度量预测值与真实值之间的绝对差距,通常用于稀疏性正则化。它的数学表达式如下:
loss = F.l1_loss(input, target)

(3)非线性操作

nn.functional 模块还包含了许多非线性操作,如池化、归一化等。

  • 最大池化(Max Pooling)
    最大池化是一种用于减小特征图尺寸的操作,通常用于卷积神经网络中。它的数学表达式如下:
output = F.max_pool2d(input, kernel_size)
  • 批量归一化(Batch Normalization)
    批量归一化是一种用于提高训练稳定性和加速收敛的技术。它的数学表达式如下:
output = F.batch_norm(input, mean, std, weight, bias)

3. nn.Module 与 nn.functional

3.1 主要区别

nn.Module 与 nn.functional 的主要区别在于:

  • nn.Module实现的layers是一个特殊的类,都是由class Layer(nn.Module)定义,会自动提取可学习的参数;
  • nn.functional中的函数更像是纯函数,由def function(input)定义。

注意:

  1. 如果模型有可学习的参数时,最好使用nn.Module。
  2. 激活函数(ReLU、sigmoid、Tanh)、池化(MaxPool)等层没有可学习的参数,可以使用对应的functional函数。
  3. 卷积、全连接等有可学习参数的网络建议使用nn.Module。
  4. dropout没有可学习参数,但建议使用nn.Dropout而不是nn.functional.dropout。

3.2 具体样例:nn.ReLU() 与 F.relu()

nn.ReLU() :

import torch.nn as nn
'''
nn.ReLU()

F.relu():

import torch.nn.functional as F
'''
out = F.relu(input)

其实这两种方法都是使用relu激活,只是使用的场景不一样,F.relu()是函数调用,一般使用在foreward函数里。而nn.ReLU()是模块调用,一般在定义网络层的时候使用。

当用print(net)输出时,nn.ReLU()会有对应的层,而F.ReLU()是没有输出的。

import torch.nn as nn
import torch.nn.functional as F

class NET1(nn.Module):
    def __init__(self):
        super(NET1, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()  # 模块的激活函数

    def forward(self, x):
        out = self.conv(x)
        x = self.bn(x)
        out = self.relu()
        return out


class NET2(nn.Module):
    def __init__(self):
        super(NET2, self).__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        out = F.relu(x)  # 函数的激活函数
        return out


net1 = NET1()
net2 = NET2()
print(net1)
print(net2)

在这里插入图片描述

参考资料

  • PyTorch的nn.Module类的详细介绍
  • PyTorch nn.functional 模块详解:探索神经网络的魔法工具箱
  • pytorch:F.relu() 与 nn.ReLU() 的区别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/767900.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大数据------JavaWeb------JSP(完整知识点汇总)

JSP 定义 JSP(Java Server Pages),即Java服务端页面。它是一种动态的网页技术,其中可以定义HTML、CSS、JS等静态内容,还可以定义Java代码的动态内容JSP HTML Java 说白了JSP就是一个页面,它既可以写HTML标…

【每日刷题】Day79

【每日刷题】Day79 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 1619. 删除某些元素后的数组均值 - 力扣(LeetCode) 2. 1365. 有多少小于当前…

Python UUID模块:深入理解与使用技巧

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

Spark入门教程(非常详细)从零基础入门到精通,看完这一篇就够了

文章目录 引言1. Spark 基础 1.1 Spark 为何物1.2 Spark VS Hadoop1.3 Spark 优势及特点 1.3.1 优秀的数据模型和丰富计算抽象1.3.2 完善的生态圈-fullstack1.3.3 spark的特点 1.4 Spark 运行模式 2. Spark Core 2.1 RDD详解 2.1.1 RDD概念2.1.2 RDD属性2.1.3 RDD API 2.1.3.1…

还有人不会挑智能猫砂盆?详细测评热门品牌糯雪、空气萝卜、CEWEY!

在现代家居生活中,宠物已成为许多家庭不可或缺的一员,而猫砂盆作为猫咪日常如厕的重要工具,选择什么类型的智能猫砂盆更是关乎猫咪健康与生活质量的关键。而市面上的智能猫砂盆品类众多,令人在挑选的时候眼花缭乱,不知…

监控平台zabbix对接grafana

目录 1.安装grafana并启动 2.浏览器访问 3.导入zabbix数据,对接grafana 4.如何导入模板 5.使用zabbix监控nginx并发量连接数 5.1 修改nginx配置 5.2 编写监控数据脚本 5.3 设置键值 5.4 在zabbix web端完成自定义监控项 5.5 连接到grafana 以上一篇博客&l…

GCN结合Transformer炸场!性能暴涨74%,效率翻3倍

最近发现了两篇效果很妙的GCN结合Transformer的最新工作,分享给大家: MP-GT:通过结合GCN和Transformer方法来增强App使用预测的准确性,实现了74.02%的性能提升,且训练时间减少了79.47%。 MotionAGFormer:结…

Dubbo简介

Apache Dubbo是一款高性能、轻量级的开源服务框架。 1.单体架构 比如现在有一个学生成绩管理平台,里面有学生管理,教师管理,成绩管理。然后将这个系统打包上线,部署在一个2核4G的服务器上,但是现在用户对成绩管理模块…

Shell Expect自动化交互(示例)

Shell Expect自动化交互 日常linux运维时,经常需要远程登录到服务器,登录过程中需要交互的过程,可能需要输入yes/no等信息,所以就用到expect来实现交互。 关键语法 ❶[#!/usr/bin/expect] 这一行告诉操…

民宿小程序开发,在线预订模式

一、开发背景 如今,随着互联网技术的快速发展,大众的生活消费都集中在了手机上,通过手机进行各种活动,同时也包括了预订酒店民宿,由此,民宿预约小程序出现在了大众的生活中。 二、民宿小程序特点 民宿小…

怎么参与场外期权?

今天期权懂带你了解怎么参与场外期权? 目前个人投资者暂时还不能直接参与场外个股期权,因为场外个股期权现在只能机构来进行交易。 所以个人投资者目前只能通过机构通道来进行操作,类似期权懂,找到期权懂经理,然后通…

深入浅出:C语言线程以及线程锁

目录 线程和线程锁概念 线程锁的概念 线程的特点 线程的使用 创建线程 pthread_create 回收线程pthread_join 退出线程 pthread_exit 线程锁的使用 线程同步之互斥锁(Mutex) 初始化互斥锁 获取互斥锁 释放互斥锁 销毁互斥锁 初始化条件变量…

SSMOA办公系统-计算机毕业设计源码19159

摘 要 随着现代信息技术的快速发展以及企业规模不断扩大,实现办公线上流程自动化已成为提升企业核心竞争力的关键。本文主要介绍的是利用Spring、SpringMVC和MyBatis(简称为:SSM)框架,MySQL数据库等先进的互联网开源技…

X86 +PC104+支持WinCE5.0,WinCE6.0,DOS,WinXP, QNX等操作系统,工业控制数据采集核心模块板卡定制

CPU 模块 是一款基于RDC 3306的SOM Express模块。RDC 3306这款X86架构的CPU是一款性能高、稳定性强的处理器。 它是一款灵活精巧的主板(尺寸为91.8mm68.6mm),可以灵活的运用于用户的底板,节约开发成本。模块的接插件使用插针形式…

基于PHP花涧订购系统的设计与实现-计算机毕业设计源码00332

摘 要 近年来,电子商务的快速发展引起了行业和学术界的高度关注。花涧订购系统旨在为用户提供一个简单、高效、便捷的花卉购物体验,它不仅要求用户清晰地查看所需信息,而且还要求界面设计精美,使得功能与页面完美融合,…

固定网国内数据传送业务经营许可证

一、国内固定网数据传送业务是什么? 固定网国内数据传送业务是指互联网数据传送业务以外的,在固定网中以有线方式提供的国内端到端数据传送业务。主要包括基于IP承载网、ATM网、X.25分组交换网、DDN网、帧中继网络的数据传送业务等。该业务属于A2类基础…

Vision Transformer论文阅读笔记

目录 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale -- Vision Transformer摘要Introduction—简介RELATED WORK—相关工作METHOD—方法VISION TRANSFORMER (VIT)—视觉Transformer(ViT) 分析与评估PRE-TRAINING DATA REQUIREMENTS—预训练数据…

PDF文档如何统计字数,统计PDF文档字数的方法有哪些?

在平时使用pdf阅读或者是处理文档的时候,常常需要统计文档的字数。pdf在查看文字时其实很简单。 PDF文档是一种常见的电子文档格式,如果需要对PDF文档中的字数进行统计,可以使用以下方法: Adobe Acrobat DC:Adobe Ac…

【服装识别系统】图像识别+Python+人工智能+深度学习+算法模型+TensorFlow

一、介绍 服装识别系统,本系统作为图像识别方面的一个典型应用,使用Python作为主要编程语言,并通过TensorFlow搭建ResNet50卷积神经算法网络模型,通过对18种不同的服装(‘黑色连衣裙’, ‘黑色衬衫’, ‘黑色鞋子’, …

从 ClickHouse 到 Apache Doris:快成物流的数智化货运应用实践

导读:随着快成物流的大宗商品产业链的不断发展,货运轨迹规划和实时数据分析的需求日益迫切,为了保障数据报表更新、用户画像圈选与物流轨迹实时更新等大数据核心系统性能,快成物流引入 Apache Doris 实时数仓升级了大数据算法平台…