使用grad-cam对ViT的输出进行可视化

使用grad-cam对ViT的输出进行可视化

文章目录

  • 使用grad-cam对ViT的输出进行可视化
    • 前言
    • 原理
    • 使用代码
    • Pytorch-grad-cam库的更多方法
    • 在MMpretrain中使用
    • 示例
    • 总结

前言

Vision Transformer (ViT) 作为现在CV中的主流backbone,它可以在图像分类任务上达到与卷积神经网络 (CNN) 相媲美甚至超越的性能。ViT 的核心思想是将输入图像划分为多个小块,然后将每个小块作为一个 token 输入到 Transformer 的编码器中,最终得到一个全局的类别 token 作为分类结果。

ViT 的优势在于它可以更好地捕捉图像中的长距离依赖关系,而不需要使用复杂的卷积操作。然而,这也带来了一个挑战,那就是如何解释 ViT 的决策过程,以及它是如何关注图像中的不同区域的。为了解决这个问题,我们可以使用一种叫做 grad-cam 的技术,它可以根据 ViT 的输出和梯度,生成一张热力图,显示 ViT 在做出分类时最关注的图像区域。

原理

grad-cam对ViT的输出进行可视化的原理是利用 ViT 的最后一个注意力块的输出和梯度,计算出每个 token 对分类结果的贡献度,然后将这些贡献度映射回原始图像的空间位置,形成一张热力图。具体来说,grad-cam+ViT 的步骤如下:

  1. 给定一个输入图像和一个目标类别,将图像划分为 14x14 个小块,并将每个小块转换为一个 768 维的向量。在这些向量之前,还要加上一个特殊的类别 token ,用于表示全局的分类信息。这样就得到了一个 197x768 的矩阵,作为 ViT 的输入。

  2. 将 ViT 的输入通过 Transformer 的编码器,得到一个 197x768 的输出矩阵。其中第一个向量就是类别 token ,它包含了 ViT 对整个图像的理解。我们将这个向量通过一个线性层和一个 softmax 层,得到最终的分类概率。

  3. 计算类别 token 对目标类别的梯度,即 ∂ y c ∂ A \frac{\partial y_c}{\partial A} Ayc ,其中 y c y_c yc 是目标类别的概率, A A A 是 ViT 的输出矩阵。这个梯度表示了每个 token 对分类结果的重要性。

  4. 对每个 token 的梯度求平均值,得到一个 197 维的向量 w w w ,其中 w i = 1 Z ∑ k ∂ y c ∂ A i k w_i = \frac{1}{Z}\sum_k \frac{\partial y_c}{\partial A_{ik}} wi=Z1kAikyc Z Z Z 是梯度的维度,即 768 。这个向量 w w w 可以看作是每个 token 的权重。

  5. 将 ViT 的输出矩阵和权重向量相乘,得到一个 197 维的向量 s s s ,其中 s i = ∑ k w k A i k s_i = \sum_k w_k A_{ik} si=kwkAik 。这个向量 s s s 可以看作是每个 token 对分类结果的贡献度。

  6. 将贡献度向量 s s s 除去第一个元素(类别 token ),并重塑为一个 14x14 的矩阵 M M M​ ,其中 M i j = s ( i − 1 ) × 14 + j + 1 M_{ij} = s_{(i-1) \times 14 + j + 1} Mij=s(i1)×14+j+1 。这个矩阵 M M M 可以看作是每个小块对分类结果的贡献度。

  7. 将贡献度矩阵 M M M 进行归一化和上采样,得到一个与原始图像大小相同的矩阵 H H H ,其中 H i j = M i j − min ⁡ ( M ) max ⁡ ( M ) − min ⁡ ( M ) H_{ij} = \frac{M_{ij} - \min(M)}{\max(M) - \min(M)} Hij=max(M)min(M)Mijmin(M) 。这个矩阵 H H H 就是我们要求的热力图,它显示了 ViT 在做出分类时最关注的图像区域。

  8. 将热力图 H H H 和原始图像进行叠加,得到一张可视化的图像,可以直观地看到 ViT 的注意力分布。

使用代码

import argparse
import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
                            ScoreCAM, \
                            GradCAMPlusPlus, \
                            AblationCAM, \
                            XGradCAM, \
                            EigenCAM, \
                            EigenGradCAM, \
                            LayerCAM, \
                            FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit

# 加载预训练的 ViT 模型
model = torch.hub.load('facebookresearch/deit:main',
'deit_tiny_patch16_224', pretrained=True)
model.eval()

# 判断是否使用 GPU 加速
use_cuda = torch.cuda.is_available()
if use_cuda:
model = model.cuda()

接下来,我们需要定义一个函数来将 ViT 的输出层从三维张量转换为二维张量,以便 grad-cam 能够处理:

def reshape_transform(tensor, height=14, width=14):
    # 去掉类别标记
    result = tensor[:, 1:, :].reshape(tensor.size(0),
    height, width, tensor.size(2))

    # 将通道维度放到第一个位置
    result = result.transpose(2, 3).transpose(1, 2)
    return result

然后,我们需要选择一个目标层来计算 grad-cam。由于 ViT 的最后一层只有类别标记对预测类别有影响,所以我们不能选择最后一层。我们可以选择倒数第二层中的任意一个 Transformer 编码器作为目标层。在这里,我们选择第 11 层作为示例:


# 创建 GradCAM 对象
cam = GradCAM(model=model,
target_layer=model.blocks[5],
use_cuda=use_cuda,
reshape_transform=reshape_transform)

接下来,我们需要准备一张输入图像,并将其转换为适合 ViT 的格式:

# 读取输入图像
image_path = "cat.jpg"
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, (224, 224))

# 预处理图像
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

# 将图像转换为批量形式
input_tensor = input_tensor.unsqueeze(0)
if use_cuda:
input_tensor = input_tensor.cuda()

最后,我们可以调用 cam 对象的 forward 方法,传入输入张量和预测类别(如果不指定,则默认为最高概率的类别),得到 grad-cam 的输出:

# 计算 grad-cam
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=input_tensor,
target_category=target_category)

# 将 grad-cam 的输出叠加到原始图像上
visualization = show_cam_on_image(rgb_img, grayscale_cam)

# 保存可视化结果
cv2.imwrite('cam.jpg', visualization)

这样,我们就完成了使用 grad-cam 对 ViT 的输出进行可视化的过程。我们可以看到,ViT 主要关注了图像中的猫的头部和身体区域,这与我们的直觉相符。通过使用 grad-cam,我们可以更好地理解 ViT 的工作原理,以及它对不同图像区域的重要性。

Pytorch-grad-cam库的更多方法

除了经典的grad-cam,库里目前支持的方法还有:

MethodWhat it does
GradCAM使用平均梯度对 2D 激活进行加权
GradCAM++类似 GradCAM,但使用了二阶梯度
XGradCAM类似 GradCAM,但通过归一化的激活对梯度进行了加权
EigenCAM使用 2D 激活的第一主成分(无法区分类别,但效果似乎不错)
EigenGradCAM类似 EigenCAM,但支持类别区分,使用了激活 * 梯度的第一主成分,看起来和 GradCAM 差不多,但是更干净
LayerCAM使用正梯度对激活进行空间加权,对于浅层有更好的效果

这里给出MMpretrain提供的对比示例:

image-20230502234318715

在MMpretrain中使用

image-20230502234419413

如果你刚好在用MMpretrain,那么有着方便的脚本文件来帮助你更加方便的进行上面的工作,具体可见:类别激活图(CAM)可视化 — MMPretrain 1.0.0rc7 文档

image-20230502234349894

示例

这里也放一些我自己试过的例子:

bc264783b1bfe850f5b8236c619f8df

下载

总结

通过使用 grad-cam,我们可以更好地理解 ViT 的工作原理,以及它是如何从图像中提取有用的特征的。grad-cam 也可以用于其他基于 Transformer 的模型,例如DeiT、Swin Transformer 等,只需要根据不同的模型结构和输出,调整相应的计算步骤即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/147795.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

22款奔驰E260L升级原厂360全景影像 高清环绕的视野

360全景影像影像系统提升行车时的便利,不管是新手或是老司机都将是一个不错的配置,无论是在倒车,挪车以及拐弯转角的时候都能及时关注车辆所处的环境状况,避免盲区事故发生,提升行车出入安全性。 360全景影像包含&…

ffmpeg5及以上-s和像素格式转换 画屏问题

环境: lsb_release -a No LSB modules are available. Distributor ID: Ubuntu Description: Ubuntu 22.10 Release: 22.10 Codename: kinetic拉下ffmpeg源码,6.0.1,4.3.6,5.1.4,依次安装作实验 ./configure --disable-x86asm …

香港丽晶酒店于维港畔隆重开幕 绮丽传奇重现香江

2023年11月14日,上海 – 日前,洲际酒店集团旗下香港丽晶酒店于2023年11月8日隆重开幕,历经华丽蜕变后正式回归维多利亚港畔,旧日经典再次伫立,唤醒过去的美好记忆。香港丽晶酒店新生后成功举办了规模盛大的开幕盛典活动…

计算属性传参的写法,在vue3项目中,

计算属性 | Vue.js 在vue3项目中,使用计算属性,在使用这个计算属性时,要传入参数写法有点怪 computed(函数) 函数里面再返回一个函数,这个函数接收参数 注意:最后的结果是,这个计算属性函数并…

【WIFI】MTK WiFi降sar如何开发

1.Sar 简介 SAR即英语“Specific Absorption Rate”的缩写。SAR值一般指手机产品中电磁波所产生的热能,它是对人体产生影响的衡量数据,单位是W/Kg(瓦/公斤)。 对于测量手机产品的“SAR”,通俗地讲,就是测量手机辐射对人体的影响是否符合标准。国际通用的标准为:以6分钟…

confluence无法打开空间目录

confluence无法打开空间目录,打开空间目录后无法显示项目 查看项目的类别信息都在 问题原因 由于索引损坏导致; This issue is caused by acorrupted index. Confluence is trying to fetch information about the spacesfrom the available index, …

如何在Windows 10上恢复丢失的文件?

丢失文件时该怎么办? 在使用电脑的过程中,我们经常会遇到丢失重要文件的情况。无论是意外删除、病毒攻击还是电脑格式化,都可能导致文件丢失。在面对这些情况时,大多数人总是会问:“如何在电脑上恢复丢失的文件&am…

44. Adb调试QT开发的Android程序实用小技巧汇总

1. 说明 使用QT开发Android应用时,如果程序本身出现了问题,很难进行调试。不像在linux或者windows系统中,可以利用QtCreator软件本身进行一些调试,安卓应用一旦在系统中安装后,如果运行中途出现什么BUG,定位问题所在很麻烦。不过,好在有adb这种调试工具可以代替QtCreat…

NineData慢查询分析:数据库性能优化的专家

在日常的数据库运维中,慢查询是一个常见的问题,它可能由复杂的查询语句、不充分的索引设计、大量数据的处理、硬件资源不足等多种因素引起。这些慢查询会消耗大量的数据库服务器资源,甚至可能导致数据库死机,无法响应业务请求。因…

【联邦学习+区块链】TORR: A Lightweight Blockchain for Decentralized Federated Learning

文章目录 I.CONTRIBUTIONII. ASSUMPTIONS AND THREAT MODELA. AssumptionsB. Threat Model III. SYSTEM DESIGNA. Design OverviewB. Block DesignC. InitializationD. Role SelectionE. Storage ProtocolF. Aggregation ProtocolG. Proof of ReliabilityH. Blockchain Consens…

Flink(五)【DataStream 转换算子(上)】

前言 这节注定是一个大的章节,我预估一下得两三天,涉及到的一些东西不懂就重新学,比如 Lambda 表达式,我只知道 Scala 中很方便,但在 Java 中有点发怵了;一个接口能不能 new 来构造对象? 答案是可以的&…

兼容最新 urllib3 版本及相关库

解决方案 对于这个问题,我们可以通过修改setup.py文件来解决。在setup.py文件中,我们将urllib3的版本范围从1.21.1到1.26改为1.21.1到最新版本。这是因为在patch中,我们已经检查了urllib3的版本,并确保其大于1…

编程的简单实例,编程零基础入门教程,中文编程开发语言工具下载

编程的简单实例,编程零基础入门教程,中文编程开发语言工具下载 给大家分享一款中文编程工具,零基础轻松学编程,不需英语基础,编程工具可下载。 这款工具不但可以连接部分硬件,而且可以开发大型的软件&…

深兰科技轮腿家用AI机器人荣获“2023年度城市更新科创大奖”

近日,“2023金砖论坛第五季金立方城市更新科创大会”在上海举行,会上发布了《第12届金砖价值榜》,深兰科技研发出品的轮腿式家用AI机器人(兰宝),因其AI技术的创新性应用,荣获了“2023年度城市更新科创大奖”。 在10月2…

JavaScript数据存储

原始类型:存储在栈内存中,每次开辟的空间大小是固定 引用类型(对象、函数、数组):存储在堆内存中,开辟的空间大小根据数据的大小决定 // 声明变量会在栈内存中开辟空间 // 创建对象在堆内存中开辟空间&…

使用 Cloudflare Worker 免费搭建网址导航网站

开源项目 GitHub:https://github.com/sleepwood/CF-Worker-Dir/ CloudFlare Worker:https://workers.cloudflare.com/ 搭建教程 首先,进入cloud flare - Worker 截图20200224180010.png 在 Cloudflare Worker 管理页面创建一个新的 Work…

11-15 AOP配置

AOP配置 基于xml 切入点表达式:方法签名描述 方法签名:访问修饰符返回值类型〔包.类.]方法名(参数列名)throws 异常声明; 语法: execution(修饰符?返回值 方法名(参数) 异常?) 注意: ?:0或者1个 通配符: * : 任意 用于返回值,方法名,类名 .. : 任意包中使用: ..:表示该包,…

Visual Studio 2019 C# 断点调试代码内存窗口显示无法计算表达式的解决问题

查看如下界面,发现右下角内存1窗口显示无法计算表达式: 按照如下步骤操作即可: 如果s1局部变量此时有值,但是内存窗口还是无法计算表达式我们可以

怎样正确选择等保测评机构开展等保测评工作?

随着大家对网络安全的重视,越来越多的企业需要做等保测评了。很多小伙伴想知道怎样正确选择等保测评机构开展等保测评工作?这里就给大家简单说说。 怎样正确选择等保测评机构开展等保测评工作? 【回答】:正确选择等保测评机构开展…

openGauss Summit 2023 | Call for Sponsor、Speaker、Demo

数据库作为千行万业数据的基石,也是推动数字经济发展的核心。随着数字经济的蓬勃发展,数据库将迎来更加广阔的应用场景和更加迫切的需求。openGauss 社区旨在汇聚产、学、研、用多方力量,聚焦基础软件核心能力的构建,引领国内数据…