深度学习绘制热力图heatmap、使模型具有可解释性

思路

获取想要解释的那一层的特征图,然后根据特征图梯度计算出权重值,加在原图上面。

Demo

在这里插入图片描述
加上类激活(cam)
在这里插入图片描述
可以看到,cam将模型认为有利于分类的特征标注了出来。
下面以ResNet50为例:
Trick:
使用

for i in model._modules.items():

可以获得模型名称和对应层。

# coding: utf-8
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
 
import torch
import torch.autograd as autograd
import torchvision.transforms as transforms

import torchvision.models as models
 
 
# 训练过的模型路径
#resume_path = r"D:\TJU\GBDB\set113\cross_validation\test1\epoch_0257_checkpoint.pth.tar"
# 输入图像路径
single_img_path = r'bicycle.jpg'
# 绘制的热力图存储路径
save_path = r'heatmap/bicycle_layer4.jpg'
 
# 网络层的层名列表, 需要根据实际使用网络进行修改
layers_names = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']
# 指定层名
out_layer_name = "layer4"
 
features_grad = 0
 
 
# 为了读取模型中间参数变量的梯度而定义的辅助函数
def extract(g):
    global features_grad
    features_grad = g
 
 
def draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False, out_layer=None):
    """
    绘制 Class Activation Map
    :param model: 加载好权重的Pytorch model
    :param img_path: 测试图片路径
    :param save_path: CAM结果保存路径
    :param transform: 输入图像预处理方法
    :param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
    :return:
    """
    # 读取图像并预处理
    global layer2
    img = Image.open(img_path).convert('RGB')
    if transform:
        img = transform(img)
    img = img.unsqueeze(0)  # (1, 3, 448, 448)
 
    # model转为eval模式
    model.eval()
 
    # 获取模型层的字典
    layers_dict = {layers_names[i]: None for i in range(len(layers_names))}
    for name,module in model._modules.items():
        #print(i, (name, module))
        layers_dict[name] = module
 
    # 遍历模型的每一层, 获得指定层的输出特征图
    # features: 指定层输出的特征图, features_flatten: 为继续完成前端传播而设置的变量
    features = img
    start_flatten = False
    features_flatten = None

    for name, layer in layers_dict.items():
        if name != out_layer and start_flatten is False:    # 指定层之前
            features = layer(features)
        elif name == out_layer and start_flatten is False:  # 指定层
            features = layer(features)
            start_flatten = True
        else:   # 指定层之后
            if name == "fc":
                break
            if features_flatten is None:
                features_flatten = layer(features)
            else:
                features_flatten = layer(features_flatten)
    #print(features_flatten.shape)
    features_flatten = torch.flatten(features_flatten, 1)
    #print(features_flatten.shape)
    output = model.fc(features_flatten)
    # 预测得分最高的那一类对应的输出score
    pred = torch.argmax(output, 1).item()
    pred_class = output[:, pred]
 
    # 求中间变量features的梯度
    # 方法1
    # features.register_hook(extract)
    # pred_class.backward()
    # 方法2
    features_grad = autograd.grad(pred_class, features, allow_unused=True)[0]
 
    grads = features_grad  # 获取梯度
    pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
    # 此处batch size默认为1,所以去掉了第0维(batch size维)
    pooled_grads = pooled_grads[0]
    features = features[0]
    print("pooled_grads:", pooled_grads.shape)
    print("features:", features.shape)
    # features.shape[0]是指定层feature的通道数
    for i in range(features.shape[0]):
        features[i, ...] *= pooled_grads[i, ...]
 
    # 计算heatmap
    heatmap = features.detach().cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)
 
    # 可视化原始热力图
    if visual_heatmap:
        plt.matshow(heatmap)
        plt.show()
 
    img = cv2.imread(img_path)  # 用cv2加载原始图像
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原始图像
    superimposed_img = heatmap * 0.7 + img  # 这里的0.4是热力图强度因子
    cv2.imwrite(save_path, superimposed_img)  # 将图像保存到硬盘
 
 
if __name__ == '__main__':
    model = models.resnet50(pretrained=True)
    #model.eval()
    transform = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    # 构建模型并加载预训练参数
    #seresnet50 = FineTuneSEResnet50(num_class=113).cuda()
    #checkpoint = torch.load(resume_path)
    #seresnet50.load_state_dict(checkpoint['state_dict'])
    draw_CAM(model, single_img_path, save_path, transform=transform, visual_heatmap=True, out_layer=out_layer_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/483846.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

二十三 超级数据查看器 讲解稿 设置

二十三 超级数据查看器 讲解稿 设置 ​点击此处 以新页面 打开B站 播放当前教学视频 点击访问app下载页面 百度手机助手 下载地址 大家好,这节课我们讲一下,超级数据查看器的设置功能。 首先,我们打开超级数据查看器, 我…

2023年全国青少年信息素养大赛(python)初赛真题

选择题(每题5分,共20题,满分100分) 1、关于列表的索引,下列说法正确的是? A.列表的索引从0开始 B.列表的索引从1开始 C.列表中可能存在两个元素的索引一致 D&#xff0…

第四百一十九回

文章目录 1. 概念介绍2. 思路与方法2.1 实现思路2.2 实现方法 3. 示例代码4. 内容总结 我们在上一章回中介绍了"自定义标题栏"相关的内容,本章回中将介绍自定义Action菜单.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在这里提到的…

web自动化3-pytest前后置夹具

一、pytest前后置(夹具)-fixture 夹具的作用:在用例执行之前和之后,需要做的准备工作和收尾工作。 用于固定测试环境,以及清理回收资源。 举个例子:访问一个被测页面-登录页面,执行测试用例过…

SpringCloud-Gateway服务网关

一、网关介绍 1. 为什么需要网关 Gateway网关是我们服务的守门神,所有微服务的统一入口。 网关的核心功能特性: 请求路由 权限控制 限流 架构图: 权限控制:网关作为微服务入口,需要校验用户是是否有请求资格&am…

算法-双指针

目录 1、双指针遍历分割:避免开空间,原地处理 2、快慢指针:循环条件下的判断 3、左右指针(对撞指针):分析具有单调性,避免重复计算 双指针又分为双指针遍历分割,快慢指针和左右指针 1、双指…

深度学习 tablent表格识别实践记录

下载代码:https://github.com/asagar60/TableNet-pytorch 下载模型:https://drive.usercontent.google.com/download?id13eDDMHbxHaeBbkIsQ7RSgyaf6DSx9io1&exportdownload&confirmt&uuid1bf2e85f-5a4f-4ce8-976c-395d865a3c37 原理&#…

C# 将 Word 转文本存储到数据库并进行管理

目录 功能需求 范例运行环境 设计数据表 关键代码 组件库引入 Word文件内容转文本 上传及保存举例 得到文件Byte[]数据方法 查询并下载Word文件 总结 功能需求 将 WORD 文件的二进制信息存储到数据库里,即方便了统一管理文件,又可以实行权限控…

查看文件内容的指令:cat,tac,nl,more,less,head,tail,写入文件:echo

目录 cat 介绍 输入重定向 选项 -b -n -s tac 介绍 输入重定向 nl 介绍 示例 more 介绍 选项 less 介绍 搜索文本 选项 head 介绍 示例 选项 -n tail 介绍 示例 选项 echo 介绍 输出重定向 追加重定向 cat 介绍 将标准输入(键盘输入)的内容打…

【微服务】Gateway服务网关

📝个人主页:五敷有你 🔥系列专栏:微服务 ⛺️稳中求进,晒太阳 Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响…

单目深度估计基础理论和论文学习总结

单目深度估计基础理论和论文学习总结 一、背景知识: 三维刚体运动的数学表示:旋转平移矩阵、旋转向量、欧拉角、四元数、轴角模型、齐次坐标、各种变换等 照相机模型:单目/双目模型,单目中的世界坐标系/相机坐标系/图像坐标系的…

MySQL表的增删改查---多表查询和联合查询

꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …

保研复习概率论1

1.什么是随机试验(random trial)? 如果一个试验满足试验可以在相同的条件下重复进行、试验所有可能结果明确可知(或者是可知这个范围)、每一次试验前会出现哪个结果事先并不确定,那么试验称为随机试验。 …

ssm+vue的消防物资存储系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频: ssmvue的消防物资存储系统(有报告)。Javaee项目,ssm vue前后端分离项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构&…

PyQT5学习--新建窗体模板

目录 1 Dialog 2 Main Window 3 Widget Dialog 模板,基于 QDialog 类的窗体,具有一般对话框的特性,如可以模态显示、具有返回值等。 Main Window 模板,基于 QMainWindow 类的窗体,具有主窗口的特性,窗口…

重生奇迹mu弓箭手技能

1、弓箭手职业技能:多重箭:同时射出三发弓箭,给予复数敌人伤害,根据弓的不同,射出的数量也不同。天堂之箭:弓箭垂直射向天际,准确的落在敌人的头顶上,造成严重的伤害。 2、连技技能…

动态规划之数字三角形模型

题目:1015. 摘花生 思路 很经典的动态规划问题。 定义:v[i][j]表示位置是i,j的花生数量,f[i][j]表示走到位置i,j所能获得的最大花生数量。初始状态:f[1][1],目标状态:f[n][m]状态转移:由于题目规定只能向…

2024-03-24 需求分析-智能问答系统-调研

一. 需求列表 基于本地知识库的问答系统对接外围系统 数字人语音识别二. 待调研的公司 2.1 音视贝 AI智能外呼_大模型智能客服系统_大模型知识库系统_杭州音视贝 (yinshibei.com) 2.2 得助智能 智能AI客服机器人-智能电话机器人客服-电话电销机器人-得助智能 (51ima.com) 2…

Redis常见数据类型(1)

Redis提供了5种数据结构, 理解每种数据类型的特点对于Redis开发运维非常重要, 同时掌握每种数据类型的常见命令, 会在使用Redis的时候做到游刃有余. 内容如下: 预备知识: 几个全局命令, 数据结构和内部编码, 单线程机制解析. 5种数据类型的特点, 命令使用, 应用场景示例. 键遍历…

03-SparkSQL入门

0 Shark Spark 的一个组件,用于大规模数据分析的 SQL 查询引擎。Shark 提供了一种基于 SQL 的交互式查询方式,可以让用户轻松地对大规模数据集进行查询和分析。Shark 基于 Hive 项目,使用 Hive 的元数据存储和查询语法,并基于Hiv…