【Pytorch】六行代码实现:特征图提取与特征图可视化

前言

之前记录过特征图的可视化:Pytorch实现特征图可视化,当时是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。今天补充另一种Pytorch官方实现好的特征提取方式,非常好用!


特征图提取

  • 前言
  • 一、Torch FX
  • 二、特征提取
    • 1.使用get_graph_node_names提取各个节点
    • 2.使用create_feature_extractor提取输出
    • 3.六行代码可视化特征图
  • 三、Reference


一、Torch FX

首先是Torch FX的介绍:FX Blog(具体可参考Reference)

FX based feature extraction is a new TorchVision utility that lets us access intermediate transformations of an input during the forward pass of a PyTorch Module. It does so by symbolically tracing the forward method to produce a graph where each node represents a single operation. Nodes are named in a human-readable manner such that one may easily specify which nodes they want to access.
Did that all sound a little complicated? Not to worry as there’s a little in this article for everyone. Whether you’re a beginner or an advanced deep-vision practitioner, chances are you will want to know about FX feature extraction. If you still want more background on feature extraction in general, read on. If you’re already comfortable with that and want to know how to do it in PyTorch, skim ahead to Existing Methods in PyTorch: Pros and Cons. And if you already know about the challenges of doing feature extraction in PyTorch, feel free to skim forward to FX to The Rescue.


也就是我们后面调用的特征提取函数是基于Torch FX实现的。总之一句话:基于FX的特征提取是一种新的TorchVision实用程序,它允许我们在PyTorch模块的前向传递过程中访问输入的中间值。


二、特征提取

1.使用get_graph_node_names提取各个节点

首先依然是查看各个网络的子层

#首先定义一个模型,这里直接加载models里的预训练模型
model = torchvision.models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#查看模型的各个层,
for name in model.named_children():
    print(name[0])
#输出,相当于把ResNet的分成了10个层
"""
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc"""

在这里插入图片描述


之前是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。比如不能获取layer2内部第一个BasicBlock的特征图输出。现在可以利用 get_graph_node_names获取任意前向传播的子节点。

import torchvision
import torch
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
"""
['x',
 'conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1.0.conv1',
 'layer1.0.bn1',
 'layer1.0.relu',
 'layer1.0.conv2',
 'layer1.0.bn2',
 'layer1.0.add',
 'layer1.0.relu_1',
 'layer1.1.conv1',
 'layer1.1.bn1',
 'layer1.1.relu',
 'layer1.1.conv2',
 'layer1.1.bn2',
 'layer1.1.add',
 'layer1.1.relu_1',
 'layer2.0.conv1',
 'layer2.0.bn1',
 'layer2.0.relu',
 'layer2.0.conv2',
 'layer2.0.bn2',
 'layer2.0.downsample.0',
 'layer2.0.downsample.1',
 'layer2.0.add',
 'layer2.0.relu_1',
 'layer2.1.conv1',
 'layer2.1.bn1',
 'layer2.1.relu',
 'layer2.1.conv2',
 'layer2.1.bn2',
 'layer2.1.add',
 'layer2.1.relu_1',
 'layer3.0.conv1',
 'layer3.0.bn1',
 'layer3.0.relu',
 'layer3.0.conv2',
 'layer3.0.bn2',
 'layer3.0.downsample.0',
 'layer3.0.downsample.1',
 'layer3.0.add',
 'layer3.0.relu_1',
 'layer3.1.conv1',
 'layer3.1.bn1',
 'layer3.1.relu',
 'layer3.1.conv2',
 'layer3.1.bn2',
 'layer3.1.add',
 'layer3.1.relu_1',
 'layer4.0.conv1',
 'layer4.0.bn1',
 'layer4.0.relu',
 'layer4.0.conv2',
 'layer4.0.bn2',
 'layer4.0.downsample.0',
 'layer4.0.downsample.1',
 'layer4.0.add',
 'layer4.0.relu_1',
 'layer4.1.conv1',
 'layer4.1.bn1',
 'layer4.1.relu',
 'layer4.1.conv2',
 'layer4.1.bn2',
 'layer4.1.add',
 'layer4.1.relu_1',
 'avgpool',
 'flatten',
 'fc']
"""

get_graph_node_names把前向传播的各个节点都列出来了形成了一个列表。比如列表中的x表示我们的输入;layer1.0.conv2表示layer1的第1个BasicBlock的conv2节点;layer3.1.conv2表示layer3的第2个BasicBlock的conv2节点;这些节点和我们上图方框中圈出来的是一一对应的,可以结合自己的网络结构具体分析。

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 4, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))


        self.fc=nn.Sequential(nn.Linear(256*6*6, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 100),
                                )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        output=self.fc(x.view(-1, 256*6*6))
        return output
    
model=AlexNet()
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
['x',
 'conv1.0',
 'conv1.1',
 'conv1.2',
 'conv2.0',
 'conv2.1',
 'conv2.2',
 'conv3.0',
 'conv3.1',
 'conv3.2',
 'conv3.3',
 'conv3.4',
 'conv3.5',
 'conv3.6',
 'view',
 'fc.0',
 'fc.1',
 'fc.2',
 'fc.3',
 'fc.4',
 'fc.5',
 'fc.6']

如果是自定义网络结构,在__init__中初始化了self.conv1self.conv2self.conv3self.fc与输出列表相对应。
conv3为例:

 self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))

总共定义了7层,3个卷积层、3个激活层、1个池化层。 输出节点列表中的conv3.0就表示conv3的第一个节点即第一个卷积层nn.Conv2d(256, 384, 3, 1, 1),同理, conv3.1表示conv3的第二个节点即nn.ReLU()

2.使用create_feature_extractor提取输出

在获取节点信息之后,我么可以利用create_feature_extractor来获取对应节点层的输出。所以get_graph_node_names只是帮助我们获取节点层的信息。

比如,我只想获取layer3layer4内部的第一个卷积层的输出即layer3.0.conv1, layer4.0.conv1

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 根据get_graph_node_names得到的节点层信息
# 定义想要得到的输出层
features = ['layer3.0.conv1', "layer4.0.conv1"]

model = torchvision.models.resnet18(
					weights=torchvision.models.ResNet18_Weights.DEFAULT)
					
# return_nodes参数就是返回对应的输出
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out

# tensor即对应的输出
"""
{'layer3.0.conv1': tensor(...),
 'layer4.0.conv1': tensor(...) }
"""

当然,并不是一定要完全按照get_graph_node_names得到的节点层信息来定义输出层。比如,我只想获取layer3整个层的输出特征图,我并不关心layer3内部子层的输出:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 定义layer3即可
# 其他层同理
features = ['layer3']
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out
"""
{'layer3': tensor(...)}
"""


return_nodes参数也可以传入一个字典,字典的键是节点层,值是自定义别名。比如{"layer3":"output1","layer4":"output2"}

features = {"layer3":"output1","layer4":"output2"}
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
x=torch.ones(1, 3, 224, 224)
out = feature_extractor(x)
out
# 输出如下
"""
{'output1': tensor(...),
 'output2': tensor(...)}

"""

3.六行代码可视化特征图

import torch
import torchvision
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractor


transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 

model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

feature_extractor = create_feature_extractor(model, return_nodes={"conv1":"output"})

original_img = Image.open("dog.jpg")

img=transform(original_img).unsqueeze(0)

out = feature_extractor(img) 

# 这里没有分通道可视化
plt.imshow(out["output"][0].transpose(0,1).sum(1).detach().numpy())

在这里插入图片描述

在这里插入图片描述

三、Reference

Torch FX官方文档:Torch FX官方文档介绍
Torch FX Blog:Feature Extraction in TorchVision using Torch FX
在这里插入图片描述
官方对四种获取特征输出的方式进行了对比,这篇Blog写的比较详细,可以仔细看看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/14807.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字孪生新能源智慧充电桩Web3D可视化运维系统

放眼全球,近十年来,新能源汽车赛道堪称“热得发烫”。伴随着进入成年期的新能源汽车行业对相关配套设备支撑水平的提升,作为其“新基建”的充电桩领域表现更为突出的价值势能。过去,在一系列补贴政策和资本刺激下,充电…

插装式两位两通电磁阀DSV-080-2NCP、DDSV-080-2NCP

特性 压力4000 PSI(276 Bar) 持续的电磁。 硬化处理的提升阀和柱塞可获得更长的寿命和低泄漏量。 有效的混式电磁铁结构。 插装阀允许交流电压。可选的线圈电压和端子。 标准的滤网低泄漏量选择 手动关闭选择。 工业化通用阀腔。 紧凑的尺寸。 两位两通常闭式双向电磁…

vue element-ui web端 引入高德地图,并获取经纬度

发版前接到一个临时新需求 ,需要在web端地址选择时用地图,并获取经纬度。 临阵发版之际加需求,真的是很头疼,于是赶紧找度娘,找api。 我引入的是高德地图,首先要去申请key , 和密钥,…

在安装docker配置端口时 centos7 防火墙规则失效

一、问题 1、做端口映射管理的时候,自己关闭了防火墙,或者开启防火墙,或者指定开关端口,但是都不影响端口的使用,这就很奇怪,也就是本文的内容! 2、思路,确认是请求到了防火墙的那…

老板们搞怪营业,品牌好感度upup真有梗

老板下场营业最经典的莫过于“老乡鸡”了。在手撕联名信事件出圈后,老乡鸡围绕束从轩创始人IP,开展了一系列社交传播宣传,比如“咯咯哒糊弄学”等。 50多岁的老乡鸡董事长束从轩,一改传统企业家严肃正经的形象,跟着老乡…

Windows下virtualbox相关软件安装设置全过程

一、下载 virtual box 程序 virtual box扩展程序-Oracle_VM_VirtualBox_Extension_Pack-7.0.8.vbox-extpack Virtualbox GuestAdditions 程序-解决分辨率,主机虚拟机之间共享文件、剪贴板等问题 http://download.virtualbox.org/virtualbox/7.0.8/ 或者 virtual b…

【shell脚本】条件语句

一、条件测试操作 1.1test命令与 [ ] 符号 测试表达试是否成立,若成立返回0,否则返回其它数值 1.1.1文件测试常用的测试操作符 符号作用-d测试是否为目录-e测试是否为目录或文件-f测试是否为文件-r测试当前用户是否有读取权限-w测试当前用户是否有写…

你掌握了stream流的全部新特性吗?

我们知道很早之前java8对于之前的版本更新了许多 新的支持,比如lamda函数式接口的支持,支持更多函数式接口的使用,对链表,数组,队列,集合等实现了Collectio接口的数据结构提供了StreamSupport.stream()支持…

运维监控工具PIGOSS BSM扩展指标介绍

PIGOSS BSM运维监控工具,除系统自带指标外,还支持添加SNMP扩展指标、脚本扩展指标、JMX扩展指标、自定义JDBC指标等,今天本文将介绍如何添加SNMP扩展指标和脚本扩展指标。 添加SNMP扩展指标 前提:需要知道指标的oid 例子&#xff…

如何实现Spring AOP以及Spring AOP的实现原理

AOP:面向切面编程,它和OOP(面向对象编程)类似。 AOP组成: 1、切面:定义AOP是针对那个统一的功能的,这个功能就叫做一个切面,比如用户登录功能或方法的统计日志,他们就各种是一个切面。切面是有切点加通知组成的。 2、连接点:所有可…

Redis入门学习笔记【一】

目录 一、Redis是什么 二、Redis数据结构 2.1 Redis 的五种基本数据类型 2.1.1String(字符串) 2.1.2字符串列表(lists) 2.1.3字符串集合(sets) 2.1.5哈希(hashes) 2.2 Red…

MQTT协议 详解

文章目录 一、啥是MQTT?1. MQTT协议特点2. 发布和订阅3. QoS(Quality of Service levels)QoS 0 —— 最多1次QoS 1 —— 最少1次QoS 2 —— 只有1次 二、MQTT 数据包结构1. MQTT固定头2. MQTT可变头 / Variable header3. Payload消息体 三、M…

Ceph入门到精通- storcli安装

storcli 是LSI公司官方提供的Raid卡管理工具,storcli已经基本代替了megacli,是一款比较简单易用的小工具。将命令写成一个个的小脚本,会将使用变得更方便。 安装简单,Windows系统下解压出来以后可以直接运行。 Linux系统默认位置…

Android程序员向音视频进阶,有前景吗

随着移动互联网的普及和发展,Android开发成为了很多人的就业选择,希望在这个行业能获得自己的一席之地。然而,随着时间的推移,越来越多的人进入到了Android开发行业,就导致目前Android开发的工作越来越难找&#xff0c…

7.Shuffle详解

1.分区规则 ps."&"指的是按位与运算,可以强制转换为正数 ps."%",假设reduceTask的个数为3,则余数为0,1,2正好指代了三个分区 以上代码的含义就是对key的hash值强制取正之后,对reduce的个数取…

大数据技术之Kafka集成

一、集成Flume 1.1 Flume生产者 (1)启动Kafka集群 zkServer.sh startnohup kafka-server-start.sh /opt/soft/kafka212/config/server.properties & (2)启动Kafka消费者 kafka-console-consumer.sh --bootstrap-server 192…

动态内存管理

文章目录 1.动态内存函数1.1free1.2malloc1.3calloc1.4realloc 2.动态内存错误2.1解引用空指针--非法访问内存2.2越界访问动态空间2.3free释放非动态空间2.4free释放部分动态空间2.5free多次释放动态空间2.6未释放动态内存 3.动态内存题目3.1形参不影响实参3.2地址返回&#xf…

APP渗透—查脱壳、反编译、重打包签名

APP渗透—查脱壳、反编译、重打包签名 1. 前言1.1. 其它 2. 安装工具2.1. 下载jadx工具2.1.1. 下载链接2.1.2. 执行文件 2.2. 下载apktool工具2.2.1. 下载链接2.2.2. 测试 2.3. 下载dex2jar工具2.3.1. 下载链接 3. 查壳脱壳3.1. 查壳3.1.1. 探探查壳3.1.2. 棋牌查壳 3.2. 脱壳3…

FVM初启,Filecoin生态爆发着力点在哪?

Filecoin 小高潮 2023年初,Filecoin发文分享了今年的三项重大变更,分别是FVM、数据计算和检索市场的更新,这些更新消息在发布后迅速吸引了市场的广泛关注。 特别是在3月14日,Filecoin正式推出了FVM,这一变革使得Filec…

多维时序 | MATLAB实现BO-CNN-GRU贝叶斯优化卷积门控循环单元多变量时间序列预测

多维时序 | MATLAB实现BO-CNN-GRU贝叶斯优化卷积门控循环单元多变量时间序列预测 目录 多维时序 | MATLAB实现BO-CNN-GRU贝叶斯优化卷积门控循环单元多变量时间序列预测效果一览基本介绍模型描述程序设计参考资料 效果一览 基本介绍 基于贝叶斯(bayes)优化卷积神经网络-门控循环…