【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

在这里插入图片描述

学习参考来自

  • Saliency Maps的原理与简单实现(使用Pytorch实现)
  • https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps

Saliency Maps 原理

《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)

在这里插入图片描述

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.

Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.

会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;

因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。


直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz,含有 imagenet 数据中验证集的 25 张图片

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

SQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def load_imagenet_val(num=None):
    """Load a handful of validation images from ImageNet.
    Inputs:
    - num: Number of images to load (max of 25)
    Returns:
    - X: numpy array with shape [num, 224, 224, 3]
    - y: numpy array of integer image labels, shape [num]
    - class_names: dict mapping integer label to class name
    """
    imagenet_fn = 'imagenet_val_25.npz'
    if not os.path.isfile(imagenet_fn):
      print('file %s not found' % imagenet_fn)
      print('Run the following:')
      print('cd cs231n/datasets')
      print('bash get_imagenet_val.sh')
      assert False, 'Need to download imagenet_val_25.npz'
    f = np.load(imagenet_fn, allow_pickle=True)
    X = f['X']  # (25, 224, 224, 3)
    y = f['y']  # (25, )
    class_names = f['label_map'].item()  # 999
    if num is not None:
        X = X[:num]
        y = y[:num]
    return X, y, class_names

图像的前处理,resize,变成向量,减均值除以方差

# 辅助函数
def preprocess(img, size=224):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

在这里插入图片描述

数据集和实验的模型

链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw


核心代码,计算 saliency maps

def compute_saliency_maps(X, y, model):
    """
    X表示图片, y表示分类结果, model表示使用的分类模型
    
    Input : 
    - X : Input images : Tensor of shape (N, 3, H, W)
    - y : Label for X : LongTensor of shape (N,)
    - model : A pretrained CNN that will be used to computer the saliency map
    
    Return :
    - saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images
    """
    # 确保model是test模式
    model.eval()
    
    # 确保X是需要gradient
    X.requires_grad_() # 仅开启了输入图片的梯度
    
    saliency = None
    
    logits = model.forward(X)  # torch.Size([5, 1000]), 前向获取 logits
    logits = logits.gather(1, y.view(-1, 1)).squeeze()  # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)
    logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.]))  # 只计算正确分类部分的loss(正确类别梯度为 1 回传)
    
    saliency = abs(X.grad.data)  # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])
    saliency, _ = torch.max(saliency, dim=1)  # torch.Size([5, 224, 224]),取 rgb 3通道的最大值
    
    return saliency.squeeze()

显示 saliency maps

def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]  # 5
    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps

# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)

# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():
    param.requires_grad = False

加载一些图片看看,25 张中抽出来 5 张

X, y, class_names = load_imagenet_val(num=5)  # X: (5, 224, 224, 3) | y: (5,) | class_names: 999

"show images"

plt.figure(figsize=(12, 6))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(X[i])
    plt.title(class_names[y[i]])
    plt.axis('off')
plt.gcf().tight_layout()
plt.show()

显示图片
在这里插入图片描述
把五张图片的 saliency maps 画出来

show_saliency_maps(X, y)

我把 25 张都画出来了
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了

# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torch

def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensor
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(torch.LongTensor(y).view(-1, 1))
    print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值


gather_example()

"""
tensor([[ 0.8119,  0.2664, -1.4168, -0.1490, -0.0675],
        [ 0.5335,  0.6304, -0.7200, -0.0974, -0.9934],
        [-0.8305,  0.5189,  0.7359,  1.5875,  0.0505],
        [ 0.4335, -1.1389, -0.7771,  0.5779,  0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],
        [2],
        [1],
        [3]])
tensor([ 0.2664, -0.7200,  0.5189,  0.5779])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/202958.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Latex中多行公式换行及设置编号位置

Latex中多行公式换行及设置编号位置_latex公式换行_泡泡和善意的博客-CSDN博客文章浏览阅读3.2w次,点赞14次,收藏97次。1. 公式换行公式换行的方式有很多种,介绍三种(1)用equation结合aligned:\begin{equat…

springmvc(基础学习整合)

SpringMVC是Spring框架提供的构建Web应用程序的全功能MVC模块。 在SpringMVC的各个组件中,处理器映射器、处理器适配器、视图解析器称为SpringMVC的三大组件。 springMVC基本介绍: http://t.csdnimg.cn/TOzw9 MVC是一种设计思想,将一个应…

Web端专业级H264/H265 直播流播放器实现-JessibucaPro播放器

概况 这个主要是参加“深圳 liveVideoStack” 的ppt的文字版的分享。 深圳 liveVideoStack 讲师介绍 关于Jessibuca 官网地址:jessibuca.comDemo: DemoDoc:DocGithub地址:Github 关于JessibucaPro 地址:JessibucaProDemo: …

Retrofit中的注解

一、Retrofit中的注解有那些? 方法注解:GET ,POST,PUT,DELETE,PATH,HEAD,OPTIONS,HTTP标记注解:FormUrlEncoded,Multpart,Streaming参数注解:Query,QueryMap,Body,Field…

liunx java 生成图片 中文显示不出来

使用java 生成图片,在图片上打的文字水印显示为一个方框,这种情况的原因,一般是liunx系统或者docker容器内,没有你在打文字水印时选择的字体 解决办法,先找一个免费的字体,比如 Alibaba-PuHuiTi-Regular.otf 然后使用字体 File newFileT new File("Alibaba-PuHuiTi-Re…

pytorch环境下安装node2vec

1.刚开始直接pip install 出错 看到是在安gensim时候出错 2.单独安gensim:https://www.lfd.uci.edu/~gohlke/pythonlibs/ 找到合适的版本,cp36就是python3.6,下载以后放在 3.

Android关于杀掉进程的方案

《风波莫听穿林打叶声》—— 苏轼 〔宋代〕 三月七日,沙湖道中遇雨,雨具先去,同行皆狼狈,余独不觉。已而遂晴,故作此词。 莫听穿林打叶声,何妨吟啸且徐行。 竹杖芒鞋轻胜马,谁怕?一蓑…

开放远程访问MySQL的权限

访问远程数据库时,产生Access denied for user ‘root‘‘xxx.xxx.xxx.xxx‘ (using password: YES)异常的解决办法 一. 异常现象 我编写了一个SpringBoot项目,项目中连接的数据库服务器地址是192.168.87.107,然后打包生成了对应的jar包&am…

Flutter应用程序的加固原理

在移动应用开发中,Flutter已经成为一种非常流行的技术选项,可以同时在Android和iOS平台上构建高性能、高质量的移动应用程序。但是,由于其跨平台特性,Flutter应用程序也面临着一些安全风险,例如反编译、代码泄露、数据…

【Openstack Train安装】九、Nova安装

Nova是OpenStack中最核心的组件,它负责根据需求提供虚拟机服务并管理虚拟机生命周期,包括虚拟机创建、虚拟机调度和热迁移等。 Nova的子组件包括nova-api、nova-compute、nova-scheduler、nova-conductor、nova-db、nova-console等等。 本文介绍Nova安装…

Windows11编译Hadoop3.3.6源码

由于https://github.com/kontext-tech/winutils还未发布3.3.6版本,因此尝试源码编译 目录 环境和安装包准备,见2zlib编译方法一:方法二: 配置文件更改1. maven阿里云镜像2. Node版本3. 越过Javadoc检查 编译HadoopError,其他报错…

每日一题:LeetCode-283. 移动零

每日一题系列(day 08) 前言: 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 &#x1f50e…

一种使用热成像和自动编码器和 3D-CNN 模型堆叠集成进行跌倒检测的新方法

A Novel Approach for Fall Detection Using Thermal Imaging and a Stacking Ensemble of Autoencoder and 3D-CNN Models A Novel Approach for Fall Detection Using Thermal Imaging and a Stacking Ensemble of Autoencoder and 3D-CNN Models:一种使用热成像和自动编码器…

linux用户管理_用户和组

2 用户管理 2.1 用户和组的基本概念和作用 2.1.1 账户实质 账户实质就是一个用户在系统上的标识,系统依据账户ID来区分每个用户的文件、进程、任务,给每个用户提供特定的工作关键(如用户的工作目录、SHELL版本以及环境配置等)&…

ubuntu终端代理配置

ubuntu浏览器的无需手动设置,主要解决在终端中的配置问题,按照下面配置后可能会ping不通一些ip,但wget/git都是可以的,具体原因以后再分析 查找端口 首先要找到自己代理对应的HTTP端口,以QV2ray软件作为示例,我为8889 手动配置 # 配置系统proxy export http_proxy=1…

怎么一键批量转换PDF/图片为Excel、Word,从而提高工作效率?

在处理大量PDF、图片文件时,我们往往需要将这些文件转换成Word或Excel格式以方便编辑和统计分析。此时,金鸣表格文字识别大师这款工具可以发挥巨大作用。下面,我们就来探讨如何使用它进行批量转换,以实现高效处理。 一、准备工作…

flask中路由route根据字典ID展示部分内容,字典名展示全部内容

from flask import Flask, jsonify , request,render_template,app Flask(__name__)app.config[JSON_AS_ASCII] Falsebooks [{"id": 1, "name": 三国演义},{"id": 2, "name": 水浒传},{"id": 3, "name": 西游记…

太极拳的招式有哪些?

太极拳的招式有很多,下面列举一些常见的太极拳招式: 起势:太极拳的第一个动作,从预备姿势开始,身体慢慢放松,重心移至左腿,然后慢慢屈膝,上体屈从向前,双臂自然下垂。 野…

docker (简介、dcoker详细安装步骤、容器常用命令)一站打包- day01

一、 为什么出现 Docker是基于Go语言实现的云开源项目。 Docker的主要目标是“Build,Ship and Run Any App,Anywhere”,也就是通过对应用组件的封装、分发、部署、运行等生命周期的管理,使用户的APP(可以是一个WEB应用或数据库应…