PyTorch小技巧:使用Hook可视化网络层激活(各层输出)

这篇文章将演示如何可视化PyTorch激活层。可视化激活,即模型内各层的输出,对于理解深度神经网络如何处理视觉信息至关重要,这有助于诊断模型行为并激发改进。

我们先安装必要的库:

 pip install torch torchvision matplotlib

加载CIFAR-10数据集并可视化一些图像。这有助于理解模型处理的输入。

 importtorchvision
 importtorchvision.transformsastransforms
 importmatplotlib.pyplotasplt
 
 # Transformations for the images
 transform=transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
 
 # Load CIFAR-10 dataset
 trainset=torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
 trainloader=torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
 
 # Function to show images
 defimshow(img):
     img=img.numpy().transpose((1, 2, 0))
     mean=np.array([0.485, 0.456, 0.406])
     std=np.array([0.229, 0.224, 0.225])
     img=std*img+mean  # unnormalize
     plt.imshow(img)
     plt.show()
 
 # Get some images
 dataiter=iter(trainloader)
 images, labels=next(dataiter)
 
 # Display images
 imshow(torchvision.utils.make_grid(images))

看着很模糊的原因是我们使用的CIFAR-10图像32x32的,很小 。因为对于小图像,处理速度很快,所以CIFAR-10称为研究的首选。

然后我们加载一个预训练的ResNet模型,并在特定的层上设置钩子函数,以在向前传递期间捕获激活。

 import torch
 from torchvision.models import resnet18
 
 # Load pretrained ResNet18
 model = resnet18(pretrained=True)
 model.eval()  # Set the model to evaluation mode
 
 # Hook setup
 activations = {}
 def get_activation(name):
     def hook(model, input, output):
         activations[name] = output.detach()
     return hook
 
 # Register hooks
 model.layer1[0].conv1.register_forward_hook(get_activation('layer1_0_conv1'))
 model.layer4[0].conv1.register_forward_hook(get_activation('layer4_0_conv1'))

这样,在通过模型处理图像时就能捕获到激活。

 # Run the model
 with torch.no_grad():
     output = model(images)

通过上面钩子函数我们获得了激活下面就可以进行可视化

 # Visualization function for activations
 def plot_activations(layer, num_cols=4, num_activations=16):
     num_kernels = layer.shape[1]
     fig, axes = plt.subplots(nrows=(num_activations + num_cols - 1) // num_cols, ncols=num_cols, figsize=(12, 12))
     for i, ax in enumerate(axes.flat):
         if i < num_kernels:
             ax.imshow(layer[0, i].cpu().numpy(), cmap='twilight')
             ax.axis('off')
     plt.tight_layout()
     plt.show()
 # Display a subset of activations
 plot_activations(activations['layer1_0_conv1'], num_cols=4, num_activations=16)

结果如下:

 plot_activations(activations['layer4_0_conv1'], num_cols=4, num_activations=16)

PyTorch的钩子函数(hooks)是一种非常有用的特性,它们允许你在训练的前向传播和反向传播过程中插入自定义操作。这对于调试、修改梯度或者理解网络的内部运作非常有帮助。

利用 PyTorch 钩子函数来可视化网络中的激活是一种很好的方式,尤其是想要理解不同层如何响应不同输入的情况下。在这个过程中,我们可以捕捉到网络各层的输出,并将其可视化以获得直观的理解。

可视化激活有助于理解卷积神经网络中的各个层如何响应输入图像中的不同特征。通过可视化不同的层,可以评估早期层是否捕获边缘和纹理等基本特征,而较深的层是否捕获更复杂的特征。这些知识对于诊断问题、调整层架构和改进整体模型性能是非常宝贵的。

https://avoid.overfit.cn/post/c63b9b1130fe425ea5b7d0bedf209b2e

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/551117.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

淘宝扭蛋机小程序开发:开启购物娱乐新纪元

在数字时代浪潮的推动下&#xff0c;小程序作为新兴的交互平台&#xff0c;正在不断引领着购物方式的革新。淘宝扭蛋机小程序的开发&#xff0c;便是这一变革中的一颗璀璨明星&#xff0c;它将传统扭蛋机的趣味与电商购物的便捷完美融合&#xff0c;为用户带来了前所未有的购物…

微服务整合Spring Cloud Gateway动态路由

前置 创建 Spring Cloud项目 参考&#xff1a;创建Spring Cloud Maven工程-CSDN博客 1. 创建一个maven jar类型项目 在idea中右键父工程-》New-》Module 创建一个maven工程 2. 引入相关依赖 在POM文件中引入下面的依赖 <project xmlns"http://maven.apache.org/P…

【C++】力扣OJ题:构建杨辉三角

Hello everybody!今天给大家介绍一道我认为比较经典的编程练习题&#xff0c;之所以介绍它是因为这道题涉及到二维数组的构建&#xff0c;如果用C语言动态构建二维数组是比较麻烦的&#xff0c;而用C中STL的vector<vector<int>>,就可以立马构建出来&#xff0c;这也…

本地生活服务平台都有哪些,靠谱吗?

随着本地生活服务的发展潜力和盈利方式被不断挖掘&#xff0c;越来越多的人开始发现其中所蕴含着的巨大商机&#xff0c;大家所熟悉的抖音、小红书和支付宝等平台也纷纷上线了本地生活板块&#xff0c;再次印证了其前景的广阔。在此背景下&#xff0c;普通人想要趁势入局分一杯…

前端开发攻略---从源码角度分析Vue3的Propy比Vue2的defineproperty到底好在哪里。一篇文章让你彻底弄懂响应式原理。

1、思考 Vue的响应式到底要干什么&#xff1f; 无非就是要知道当你读取对象的时候&#xff0c;要知道它读了。要做一些别的事情无非就是要知道当你修改对象的时候&#xff0c;要知道它改了。要做一些别的事情所以要想一个办法&#xff0c;把读取和修改的动作变成一个函数&#…

Python语言在地球科学领域中的实践技术应用教程

原文链接&#xff1a;Python语言在地球科学领域中的实践技术应用教程https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247601506&idx6&snee03d6abdbcbe0216a07340c1b49cb8d&chksmfa820c85cdf585934051f111b1d95877ba54e71cbe74be87f9e0bd07b8d46d97bbc7d…

深入理解大语言模型微调技术

一、概念解析 1、什么是微调&#xff08;Fine-tuning&#xff09;&#xff1f; 大模型微调&#xff0c;也称为Fine-tuning&#xff0c;是指在已经预训练好的大型语言模型基础上&#xff08;一般称为“基座模型”&#xff09;&#xff0c;使用特定的数据集进行进一步的训练&am…

基于Springboot的某大药房管理系统

开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven…

小成本搏大流量:微信/支付宝小程序搜索排名优化

随着移动互联网的快速发展&#xff0c;小程序已成为企业和个人开发者重要的流量入口和业务承载平台。而小程序搜索排名则是影响小程序曝光量、用户获取及业务转化的关键因素。小柚在本文和大家探讨如何制定有效的优化方案&#xff0c;提升小程序在搜索结果中的排名。 首先跟我…

【图文教程】在PyCharm中导入Conda环境

文章目录 &#xff08;1&#xff09;在Anaconda Prompt中新建一个conda虚拟环境&#xff08;2&#xff09;使用PyCharm打开需要搭建环境的项目&#xff08;3&#xff09;配置环境 &#xff08;1&#xff09;在Anaconda Prompt中新建一个conda虚拟环境 conda create - myenv py…

Python SQL解析和转换库之sqlglot使用详解

概要 Python SQLGlot是一个基于Python的SQL解析和转换库,可以帮助开发者更加灵活地处理和操作SQL语句。本文将介绍SQLGlot库的安装、特性、基本功能、高级功能、实际应用场景等方面。 安装 安装SQLGlot库非常简单,可以使用pip命令进行安装: pip install sqlglot安装完成后…

Jenkins打包app并通过openssh上传到服务器

1、下载安装openssh 网上很多教程&#xff0c;包括开端口的&#xff0c;可以搜下 2、配置openssh根目录 进入C:\ProgramData\ssh打开文件sshd_config&#xff0c;添加配置ChrootDirectory D:\wxs\soft&#xff0c;想改端口的也在这个文件 3、安装Jenkins 参考上一篇 4、新…

WordPress采集插件大比拼:哪款才是站长的救星?

本着节约站长宝贵时间的理念&#xff0c;WordPress网站内容管理系统应采取自动化采集技术。因此催生了各种相应的WordPress提取插件和软件。然而&#xff0c;在众多可用选项面前&#xff0c;如何做出最佳选择成为关键问题。权衡的要素包括功能可靠性、易用性、稳定性以及定制程…

Java Bean 通用方法自动生成

原文&#xff1a;https://blog.iyatt.com/?p14637 使用 Lombok&#xff1a;https://mvnrepository.com/artifact/org.projectlombok/lombok 写了一个 Person 类&#xff0c;通过 Lombok 就可以生成通用的方法 package com.iyatt;import lombok.AllArgsConstructor; import l…

Linux系统中LVM与磁盘配额

目录 一、LVM逻辑卷管理 二、LVM的管理命令 物理卷管理 卷组管理 逻辑卷管理 *创建并使用LVM步骤 三、磁盘配额概述 实现磁盘限额的条件 Linux 磁盘限额的特点 四、磁盘配额管理 磁盘限额 一、LVM逻辑卷管理 能够在保持现有数据不变的情况下动态调整磁盘容量&#…

如何解决SSL证书不生效,有免费SSL证书吗?

SSL&#xff08;Secure Sockets Layer&#xff09;证书起着举足轻重的作用。它为网站提供加密服务&#xff0c;从而确保用户数据在传输过程中的安全性。然而&#xff0c;有时我们可能会遇到SSL证书不生效的问题&#xff0c;这不仅会降低网站的信任度&#xff0c;还可能导致数据…

支付宝支付之SpringBoot整合支付宝入门

支付宝支付 对接流程 申请阿里支付官方企业账号配置应用签约产品获取RSAKey&#xff08;非对称加密&#xff09;必须获得两个加密串&#xff1a;一个公钥&#xff0c;一个密钥SDK功能开发业务对接支付回调支付组件 核心所需的参数 APPID商家私钥支付宝公钥支付回调地址网关…

Python使用pymssql连接 SQLServer2008 报错:DB-Lib error message 20002, severity 9

Python使用pymssql连接 SQLServer2012没有问题&#xff0c;但是连接SQLServer2008就会报错DB-Lib error message 20002, severity 9&#xff0c;问题解决 可以打印详细连接过程的方式&#xff1a; import pymssql import os os.environ[TDSDUMP] stdout # 用于打印连接详细过…

安全认证Kerberos详解

文章目录 一、Kerberos入门与使用1、Kerberos概述1.1 什么是Kerberos1.2 Kerberos术语1.3 Kerberos认证原理 2、Kerberos安装2.1 安装Kerberos相关服务2.2 修改配置文件2.3 其他配置与启动 3、Kerberos使用概述3.1 Kerberos数据库操作3.2 Kerberos认证操作 二、Hadoop Kerberos…

【随笔】Git 高级篇 -- 远程与本地不一致导致提交冲突 git push --rebase(三十一)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…