PyTorch使用教程(10)-torchinfo.summary网络结构可视化详细说明

1、基本介绍

torchinfo是一个为PyTorch用户量身定做的开源工具,其核心功能之一是summary函数。这个函数旨在简化模型的开发与调试流程,让模型架构一目了然。通过torchinfosummary函数,用户可以快速获取模型的详细结构和统计信息,如模型的层次结构、输入/输出维度、参数数量、多加操作(Mult-Adds)等关键信息。

2、安装

首先,你需要安装torchinfo库。可以通过pip进行安装:

pip install torchinfo

3、导入

安装完成后,需要在你的Python脚本中导入torchinfo模块:

from torchinfo import summary

4、函数原型定义

torchinfo的summary函数原型定义如下:

def summary(model: nn.Module, input_data: torch.Tensor | tuple[torch.Tensor, ...] | tuple[int, ...] | None = None, batch_dim: int = 0, col_widths: tuple[int, ...] | None = None, col_names: tuple[str, ...] | None = None, device: str | torch.device | None = None, dtypes: tuple[torch.dtype, ...] | None = None, verbose: int = 1, **kwargs)

参数说明

  • model: 要分析的PyTorch模型,必须是torch.nn.Module的实例。
  • input_data: 用于模型前向传播的输入数据。它可以是一个torch.Tensor对象,也可以是一个包含多个输入张量的元组。此外,还可以提供一个表示输入尺寸的元组,例如(batch_size, channels, height, width)。
  • batch_dim: 指定输入张量中哪个维度是批量大小(batch size)。默认为0。
  • col_widths: 指定输出列宽的元组。如果未指定,则自动计算列宽以适应输出。
  • col_names: 指定输出列名的元组。如果未指定,则使用默认列名。
  • device: 指定模型运行的设备(如’cpu’或’cuda’)。如果未指定,则自动选择。
  • dtypes: 指定输入张量的数据类型。如果未指定,则自动推断。
  • verbose: 控制输出信息的详细程度。默认为1,表示输出基本信息。设置为2或更高可以获得更详细的输出。
  • kwargs: 其他关键字参数,可以传递给模型的前向传播函数。

5、使用方法

下面通过几个示例来展示如何使用torchinfo的summary函数。
5.1 使用预定义模型
首先,我们使用PyTorch预定义的模型(如torchvision.models.resnet50)来展示如何使用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary
# 定义模型
model = models.resnet18(pretrained=False)

# 使用summary函数打印模型概况
summary(model, input_size=(1, 3, 224, 224))

在这个示例中,我们加载了一个未预训练的ResNet50模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, channels, height, width)。
在这里插入图片描述

5.2 使用自定义模型
接下来,我们定义一个简单的自定义模型,并使用summary函数打印其概况。

import torch
import torch.nn as nn
from torchinfo import summary

# 定义一个简单的两层全连接神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = SimpleModel()

# 使用summary函数打印模型概况
summary(model, input_size=(100,))

在这个示例中,我们定义了一个简单的两层全连接神经网络模型,并使用summary函数打印了模型的概况。input_size参数指定了输入数据的大小,即(batch_size, features)。由于我们的模型是一个全连接层,所以我们只指定了特征数量。
在这里插入图片描述

5.3 使用自定义输入数据

有时候,可能想要使用实际的输入数据来查看模型的概况。下面是一个示例,展示了如何使用自定义输入数据来调用summary函数。

import torch
import torchvision.models as models
from torchinfo import summary

# 定义模型
model = models.resnet50(pretrained=False)

# 创建自定义输入数据
input_data = torch.randn(1, 3, 224, 224)  # batch_size=1, channels=3, height=224, width=224

# 使用summary函数打印模型概况
summary(model, input_data=input_data)

在这个示例中,我们创建了一个形状为(1, 3, 224, 224)的随机张量作为输入数据,并使用summary函数打印了模型的概况。注意,这里我们使用input_data参数而不是input_size参数来指定输入数据。

5.4 调整输出格式
torchinfo允许通过col_widths和col_names参数来调整输出的格式。下面是一个示例,展示了如何自定义输出列宽和列名。

import torch
import torchvision.models as models
from torchinfo import summary

# 定义模型
model = models.resnet50(pretrained=False)

# 使用summary函数打印模型概况,并自定义输出列宽和列名
summary(model, input_size=(3, 224, 224), col_widths=(30, 30, 20, 20),
        col_names=('input_size', 'output_size', 'kernel_size', 'num_params'))

在这个示例中,我们自定义了输出列宽和列名。col_widths参数指定了每列的宽度(以字符为单位),而col_names参数指定了每列的列名。这样,就可以根据需要来调整输出的格式了。

6、小结

torchinfo的summary函数是一个强大的工具,可以方便地查看PyTorch模型的结构和参数数量。通过本文的介绍,应该已经掌握了如何使用summary函数来打印模型的概况。无论使用预定义模型还是自定义模型,无论是使用输入尺寸还是自定义输入数据,torchinfo都能提供详细而清晰的输出信息。希望这篇文章能对你有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/956988.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java模拟路由协议-rip(路由器仿真实验)

前言: 好久不见,有段时间没有写文章了,本篇文章,由Blue我带大家来复现rip协议。我们以 b站湖南教师匠所讲rip的视频中的例子为我这篇文章所模拟的路由路径 如图: 模拟路径 视频:http://【深入浅出计算机网络…

32V/4A,降压DCDC转换器CP8384百分百占空比输出ESOP8封装,可适用HUB等电路设计

特点: ● Supply Voltage Range: 4.1V~32V ● Input voltage up to 40V ● 4A Continuous Output Current ● Up to 95% Output Efficiency ● CC/CV control ● 350kHz Switching Frequency ● Built-in Soft Start ● 100% Maximum Duty Cycle ● No External Com…

缓存、数据库双写一致性解决方案

双写一致性问题的核心是确保数据库和缓存之间的数据同步,以避免缓存与数据库数据不同步的问题,尤其是在高并发和异步环境下。本文将探讨双写一致性面临的主要问题和解决方案,重点关注最终一致性。 本文讨论的是最终一致性问题 双写一致性面…

【学习笔记15】如何在非root服务器中,安装属于自己的redis

一、下载安装包 官网下载黑马程序员给的安装包(redis-6.2.6) 二、将安装包上传至服务器 我将安装包上传在我的文件夹/home/XXX,指定路径中/src/local/redis/,绝对路径为/home/XXX/src/local/redis/解压安装包 XXXomega:~$ cd …

计算机网络 (51)鉴别

前言 计算机网络鉴别是信息安全领域中的一项关键技术,主要用于验证用户或信息的真实性,以及确保信息的完整性和来源的可靠性。 一、目的与重要性 鉴别的目的是验明用户或信息的正身,对实体声称的身份进行唯一识别,以便验证其访问请…

【大模型】ChatGPT 高效处理图片技巧使用详解

目录 一、前言 二、ChatGPT 4 图片处理介绍 2.1 ChatGPT 4 图片处理概述 2.1.1 图像识别与分类 2.1.2 图像搜索 2.1.3 图像生成 2.1.4 多模态理解 2.1.5 细粒度图像识别 2.1.6 生成式图像任务处理 2.1.7 图像与文本互动 2.2 ChatGPT 4 图片处理应用场景 三、文生图操…

后端:MyBatis

文章目录 1. MyBatis1-1. Mybatis 工具类的封装1-2. Mybatis 通过集合或实体类传递参数-实现插入数据(增)1-3. MyBatis 实现删除数据(删)1-4. MyBatis 实现修改数据(改)1-5. MyBatis 实现查询数据(查) 2. MyBatis 配置文件中的一些标签和属性2-1.environments标签2-2. dataSour…

将 AzureBlob 的日志通过 Azure Event Hubs 发给 Elasticsearch(1.标准版)

问题 项目里使用了 AzureBlob 存储了用户上传的各种资源文件,近期 AzureBlob 的流量费用增长很快,想通过分析Blob的日志,获取一些可用的信息,所以有了这个需求:将存储账户的日志(读写,审计&…

数字化时代,传统代理模式的变革之路

在数字化飞速发展的今天,线上线下融合(O2O)成了商业领域的大趋势。这股潮流,正猛烈冲击着传统代理模式,给它带来了新的改变。 咱们先看看线上线下融合现在啥情况。线上渠道那是越来越多,企业纷纷在电商平台…

【AI | pytorch】torch.polar的使用

一、torch.polar的使用 torch.polar 是 PyTorch 中用来生成复数张量的一个函数,但它与数学中的复数表达式 ( z re^{i\theta} ) 是等价的。 具体来说,torch.polar(abs, angle) 接受两个实数张量参数: abs:表示复数的模长&#…

LeetCode 110.平衡二叉树

题目描述 给定一个二叉树,判断它是否是平衡二叉树。 示例 1: 示例 2: 输入:root [1,2,2,3,3,null,null,4,4] 输出:false 示例 3: 输入:root [] 输出:true 提示: …

数据结构(Java版)第四期:ArrayLIst和顺序表(上)

目录 一、顺序表 1.1. 接口的实现 二、ArrayList简介 2.1. ArrayList的构造 2.2. ArrayList的常见操作 2.3. ArrayList的扩容机制 三、ArrayList的具体使用 3.1. 洗牌算法 3.2. 杨辉三角 一、顺序表 上一期我们讲到过,顺序表本质上和数组是差不多的&#…

阿里云 Serverless 助力盟主直播:高并发下的稳定性和成本优化

在直播场景中,阿里云 Serverless 应用引擎 SAE 提供的无缝弹性伸缩与极速部署能力,确保直播间高并发时的流畅体验,降低了我们的运营成本,简化了运维流程。结合阿里云云原生数据库 PolarDB 的 Serverless 能力,实现了数…

【机器学习实战入门】基于深度学习的乳腺癌分类

什么是深度学习? 作为对机器学习的一种深入方法,深度学习受到了人类大脑和其生物神经网络的启发。它包括深层神经网络、递归神经网络、卷积神经网络和深度信念网络等架构,这些架构由多层组成,数据必须通过这些层才能最终产生输出。…

Qt之QDjango-db的简单使用

QDjango是一款由C编写、依托于Qt库的Web开发框架,其设计理念受到了广受欢迎的Python框架Django的影响。这个项目旨在提供一个高效、灵活且易于使用的工具集,帮助开发者构建高质量的Web应用。其项目地址: https://gitcode.com/gh_mirrors/qd/qdjango&…

[2025分类时序异常检测指标R-AUC与VUS]

梳理了一下分类中常见的指标,这些指标与时序异常检测中新提出的A-RUC与VUS之间的关系 真正例(True Positive,TP): 被正确识别为正样本的数量。真负例(True Negative,TN): 被正确识别为负样本的数量。假正例(False Positive ,FP): 被错误识为正样本数量假负例(Fals…

python3GUI--仿崩坏三二次元登录页面(附下载地址) By:PyQt5

文章目录 一.前言二.预览三.实现方案1.实现原理1.PyQt52. 具体实现 2.UI设计1.UI组件化、模块化2.UI设计风格思路 3.项目代码结构4.使用方法3.代码分享1.支持跳转网页的QLabel组件2.三角形ICON按钮 四.总结 大小:33.3 …

STM32 FreeRTOS中断管理

目录 FreeRTOS的中断管理 1、STM32中断优先级管理 2、FreeRTOS任务优先级管理 3、寄存器和内存映射寄存器 4、BASEPRI寄存器 5、FreeRTOS与STM32中断管理结合使用 vPortRaiseBASEPRI vPortSetBASEPRI 6、FromISR后缀 7、在中断服务函数中调用FreeRTOS的API函数需注意 F…

如何在idea中搭建SpringBoot项目

如何在idea中快速搭建SpringBoot项目 目录 如何在idea中快速搭建SpringBoot项目前言一、环境准备:搭建前的精心布局 1.下载jdk (1)安装JDK:(2)运行安装程序:(3)设置安装…

Linux:expect spawn简介与用法

一、背景 大家在使用linux系统的很多时候,都用linux指令来实现一些操作,执行特定的job,有时一些场景中需要执行交互指令来完成任务,比如ssh登录这个命令大家一定很熟悉: ssh-keygen -t rsa # 以及 ssh-copy-id -i /hom…