PyTorch使用教程(6)一文讲清楚torch.nn和torch.nn.functional的区别

torch.nn torch.nn.functional 在 PyTorch 中都是用于构建神经网络的重要组件,但它们在设计理念、使用方式和功能上存在一些显著的区别。以下是关于这两个模块的详细区别:

1. 继承方式与结构

torch.nn

  • torch.nn 中的模块大多数是通过继承 torch.nn.Module 类来实现的。这些模块都是 Python 类,包含了神经网络的各种层(如卷积层、全连接层等)和其他组件(如损失函数、优化器等)。
  • torch.nn 中的模块可以包含可训练参数,如权重和偏置,这些参数在训练过程中会被优化。

torch.nn.functional

  • torch.nn.functional 中的函数是直接调用的,无需实例化。这些函数通常用于执行各种非线性操作、损失函数计算、激活函数应用等。
  • torch.nn.functional 中的函数没有可训练参数,它们只是执行操作并返回结果。

2. 实现方式与调用方式

torch.nn

  • torch.nn 中的模块是基于面向对象的方法实现的。开发者需要创建类的实例,并在类的 forward 方法中定义数据的前向传播路径。
  • torch.nn 中的模块通常需要先创建模型实例,再将输入数据传入模型中进行前向计算。

torch.nn.functional

  • torch.nn.functional 中的函数是基于函数式编程实现的。它们提供了灵活的接口,允许开发者以函数调用的方式轻松定制和扩展神经网络架构。
  • torch.nn.functional 中的函数可以直接调用,只需要将输入数据传入函数中即可进行前向计算。

3. 使用场景与优势

torch.nn

  • torch.nn 更适合用于定义有状态的模块,如包含可训练参数的层。
  • 当定义具有变量参数的层时(如卷积层、全连接层等),torch.nn 会帮助初始化好变量,并且模型类本身就是 nn.Module 的实例,看起来会更加协调统一。
  • torch.nn 可以结合 nn.Sequential 来简化模型的构建过程。

torch.nn.functional

  • torch.nn.functional 中的函数相比 torch.nn 更偏底层,封装性不高但透明度很高。开发者可以在其基础上定义出自己想要的功能。
  • 使用 torch.nn.functional 可以更方便地进行函数组合、复用等操作,适合那些喜欢使用函数式编程风格的开发者。当激活函数只需要在前向传播中使用时,使用 torch.nn.functional 中的激活函数会更加简洁。

4. 权重与参数管理

torch.nn

  • torch.nn 中的模块会自动管理权重和偏置等参数,这些参数可以通过 model.parameters() 方法获取,并用于优化算法的训练。

torch.nn.functional

  • torch.nn.functional 中的函数不直接管理权重和偏置等参数。如果需要使用这些参数,开发者需要在函数外部定义并初始化它们,然后将它们作为参数传入函数中。

5.举例说明

例子1:定义卷积层

使用 torch.nn

import torch.nn as nn

class MyConvNet(nn.Module):
    def __init__(self):
        super(MyConvNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        return x

# 实例化模型
model = MyConvNet()

# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as F

def my_conv_net(input_tensor, weight, bias=None):
    output_tensor = F.conv2d(input_tensor, weight, bias=bias, stride=1, padding=1)
    return output_tensor

# 定义卷积核的权重和偏置
weight = nn.Parameter(torch.randn(16, 1, 3, 3))
bias = nn.Parameter(torch.randn(16))

# 传入输入数据
input_tensor = torch.randn(1, 1, 32, 32)
output_tensor = my_conv_net(input_tensor, weight, bias)

在这个例子中,使用 torch.nn 定义了一个包含卷积层的模型类,而使用 torch.nn.functional 则是通过函数直接进行卷积操作。注意在使用 torch.nn.functional 时,需要手动定义和传递卷积核的权重和偏置。

例子2:应用激活函数

使用 torch.nn

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(x)
        return x

# 实例化模型
model = MyModel()

# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

使用 torch.nn.functional

import torch.nn.functional as F

def my_model(input_tensor):
    output_tensor = F.relu(input_tensor)
    return output_tensor

# 传入输入数据
input_tensor = torch.randn(1, 10)
output_tensor = my_model(input_tensor)

在这个例子中,使用 torch.nn 定义了一个包含 ReLU 激活函数的模型类,而使用 torch.nn.functional 则是通过函数直接应用 ReLU 激活函数。

例子3:定义和计算损失

使用 torch.nn

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        x = self.linear(x)
        return x

# 实例化模型
model = MyModel()

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()

# 前向传播和计算损失
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target)

使用 torch.nn.functional

import torch.nn.functional as F

def my_model(input_tensor):
    output_tensor = torch.matmul(input_tensor, weight.t()) + bias
    return output_tensor

# 定义权重和偏置
weight = nn.Parameter(torch.randn(10, 2))
bias = nn.Parameter(torch.randn(2))

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 传入输入数据和标签
input_tensor = torch.randn(1, 10)
target = torch.tensor()

# 前向传播和计算损失
output_tensor = my_model(input_tensor)
loss = criterion(output_tensor, target)

在这个例子中,使用 torch.nn 定义了一个包含全连接层的模型类,并使用了 torch.nn 中的损失函数来计算损失。而使用 torch.nn.functional 则是通过函数直接进行线性变换,并使用 torch.nn 中的损失函数来计算损失。注意在使用 torch.nn.functional 时,需要手动定义和传递权重和偏置。

6. 小结

torch.nn 和 torch.nn.functional 在定义神经网络组件、应用激活函数和计算损失等方面存在显著的区别。torch.nn 提供了一种面向对象的方式来构建模型,而 torch.nn.functional 则提供了一种更灵活、更函数式的方式来执行相同的操作。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/957198.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

海思Hi3516CV610 -----芯片说明

Hi3516CV610这颗超高清智慧视觉SoC芯片是由海思技术有限公司推出的,其首发量产的时间是在2024年4。标志着海思正式回归安防市场,并在IPC SoC市场中展开竞争。 关键特性 ●4K20,6M30分辨率 ●双目实时接入,支撑枪球一体机等双目机…

iOS-支付相关

支付宝支付 #import <AlipaySDK/AlipaySDK.h> //orderStrAliPay为服务端传的订单信息 //fromScheme为应用配置的schemeUrl标识&#xff0c;用户支付包支付成功后跳转会本应用内 //callback回调需要在- (BOOL)application:(UIApplication *)app openURL:(NSURL *)url 中调…

go读取excel游戏配置

1.背景 游戏服务器&#xff0c;配置数据一般采用csv/excel来作为载体&#xff0c;这种方式&#xff0c;策划同学配置方便&#xff0c;服务器解析也方便。在jforgame框架里&#xff0c;我们使用以下的excel配置格式。 然后可以非常方便的进行数据检索&#xff0c;例如&#xff…

热点营销新视角:品牌如何巧妙“跟风”不踩雷?

在当今信息爆炸的时代&#xff0c;热点事件如同流星划过夜空&#xff0c;转瞬即逝&#xff0c;却总能吸引无数眼球。而软文营销作为品牌推广的重要手段之一&#xff0c;如何巧妙“跟风”热点&#xff0c;既能借势提升品牌曝光度&#xff0c;又能避免陷入雷区&#xff0c;成为众…

Trimble三维激光扫描-地下公共设施维护的新途径【沪敖3D】

三维激光扫描技术生成了复杂隧道网络的高度详细的三维模型 项目背景 纽约州北部的地下通道网络已有100年历史&#xff0c;其中包含供暖系统、电线和其他公用设施&#xff0c;现在已经开始显露出老化迹象。由于安全原因&#xff0c;第三方的进入受到限制&#xff0c;在没有现成纸…

TDengine 做 Apache SuperSet 数据源

‌Apache Superset‌ 是一个现代的企业级商业智能&#xff08;BI&#xff09;Web 应用程序&#xff0c;主要用于数据探索和可视化。它由 Apache 软件基金会支持&#xff0c;是一个开源项目&#xff0c;它拥有活跃的社区和丰富的生态系统。Apache Superset 提供了直观的用户界面…

K8S-Pod的环境变量,重启策略,数据持久化,资源限制

1. Pod容器的三种重启策略 注意&#xff1a;k8s所谓的重启容器指的是重新创建容器 cat 07-restartPolicy.yaml apiVersion: v1 kind: Pod metadata:name: nginx-web-imagepullpolicy-always spec:nodeName: k8s233.oldboyedu.com## 当容器异常退出时&#xff0c;始终重启容器r…

03垃圾回收篇(D1_垃圾收集器算法底层导论)

目录 一、为什么我们要去了解垃圾收集和内存分配 二、对象已死&#xff1f; 1. 引用计数算法 2. 可达性分析算法 3. 再谈引用 4. 生存还是死亡 5. 回收方法区 三、垃圾收集算法 1. 简介 2. 分代收集理论 2.1. 弱分代/强分代假说 2.2. 前面两代假说的缺陷 3. 标记-清…

【wiki知识库】08.添加用户登录功能--后端SpringBoot部分

目录 一、今日目标? 二、SpringBoot后端实现 2.1 新增UserLoginParam 2.2 修改UserController 2.3 UserServiceImpl代码 2.4 创建用户上下文工具类 2.5?通过token校验用户&#xff08;重要&#xff09; 2.6 创建WebMvcConfig 2.7 用户权限校验拦截器 一、今日目标 上…

分布式 IO 模块:开启药品罐装产线高效生产新纪元

在药品生产的精密领域&#xff0c;每一个环节都关乎着客户的健康与安全。药品罐装产线作为药品生产的关键环节&#xff0c;其高效运行与精准控制至关重要。明达技术MR30分布式 IO 模块&#xff0c;正以其卓越的性能&#xff0c;成为实现药品罐装产线高效控制&#xff0c;确保产…

【北京迅为】iTOP-4412全能版使用手册-第八十七章 安装Android Studio

iTOP-4412全能版采用四核Cortex-A9&#xff0c;主频为1.4GHz-1.6GHz&#xff0c;配备S5M8767 电源管理&#xff0c;集成USB HUB,选用高品质板对板连接器稳定可靠&#xff0c;大厂生产&#xff0c;做工精良。接口一应俱全&#xff0c;开发更简单,搭载全网通4G、支持WIFI、蓝牙、…

基于海思soc的智能产品开发(视频的后续开发)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 前面我们讨论了camera&#xff0c;也讨论了屏幕驱动&#xff0c;这些都是基础的部分。关键是&#xff0c;我们拿到了这些视频数据之后&#xff0c;…

Linux -- HTTP 请求 与 响应 报文

目录 请求报文&#xff1a; 请求方法 响应报文&#xff1a; 状态码 与 状态码描述 共性 常见的报头 请求报文&#xff1a; 请求方法 方法说明GET获取资源POST传输实体主体PUT传输文件HEAD获得报文首部DELETE删除文件OPTIONS询问支持的方法TRACE追踪路径CONNECT要求用…

HTML<img>标签

例子 如何插入图片&#xff1a; <img src"img_girl.jpg" alt"Girl in a jacket" width"500" height"600"> 下面有更多“自己尝试”的示例。 定义和用法 该<img>标签用于在 HTML 页面中嵌入图像。 从技术上讲&#x…

C++ 面向对象(继承)

三、继承 3.1 继承的概念 基于一个已有的类 去重新定义一个新的类&#xff0c;这种方式我们叫做继承 关于继承的称呼 一个类B 继承来自 类 A 我们一般称呼 A类&#xff1a;父类 基类 B类: 子类 派生类 B继承自A A 派生了B 示例图的语法 class vehicle // 车类 {}class …

mfc操作json示例

首先下载cJSON,加入项目; 构建工程,如果出现, fatal error C1010: unexpected end of file while looking for precompiled head 在cJSON.c文件的头部加入#include "stdafx.h"; 看情况,可能是加到.h或者是.cpp文件的头部,它如果有包含头文件, #include &…

基于微信小程序的模拟考试系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导&#xff0c;欢迎高校老师/同行前辈交流合作✌。 技术范围&#xff1a;SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容&#xff1a;…

Android10.0定制服务 APK安装或者更新过自动打开APK

需求描述 当用户APK安装|更新 成功之后自动打开APK。 实现思路 编写服务 注册接受安装广播 并判断安装的APK包名是否为客户APK包名,并自动打开。 具体实现 diff --git a/android/build/make/core/tasks/check_boot_jars/package_whitelist.txt b/android/build/make/cor…

linux 下tensorrt的yolov8的前向推理(python 版本)的实现

一、yolov8的python实现的环境搭建 #通过pip安装 pip install ultralytics #通过git克隆GitHub仓库 git clone <https://github.com/ultralytics/ultralytics.git> cd ultralytics #安装依赖 pip install -r requirements.txt #执行推理 yolo predict model./yolov8n.pt …

AI News(1/21/2025):OpenAI 安全疏忽:ChatGPT漏洞引发DDoS风险/OpenAI 代理工具即将发布

1、OpenAI 的安全疏忽&#xff1a;ChatGPT API 漏洞引发DDoS风险 德国安全研究员 Benjamin Flesch 发现了一个严重的安全漏洞&#xff1a;攻击者可以通过向 ChatGPT API 发送一个 HTTP 请求&#xff0c;利用 ChatGPT 的爬虫对目标网站发起 DDoS 攻击。该漏洞源于 OpenAI 在处理…