【CAM】CAM(Class Activation Mapping)——可视化CNN的特征定位

文章目录

  • 一、CAM(Class Activation Mapping)
  • 二、CAM技术实现
    • 2.1 网络修改
    • 2.2 微调
    • 2.2 特征提取
  • 三、总结
  • Reference


完整代码见Github :https://github.com/capsule2077/CAM-Visualization ,如果有用可以点个Star,谢谢!


一、CAM(Class Activation Mapping)

CAM是类别激活映射图,提出于论文:Learning Deep Features for Discriminative Localization ,这是一篇发表于2015年的CVPR论文。论文主要强调了卷积神经网络具有的特征定位能力,但是堆叠的全连接层会破环这种能力。为了证明这种能力,作者通过CAM技术把这种能力可视化出来,下面的图很直观,这是ResNet18对最左上角的原图进行预测的Top5的结果,图的上方注明了具体的类别和对应的概率值,预测结果的图以类似热力图的样式进行标注,颜色越深的地方就是模型越关注的特征,针对不同的分类情况模型的关注点都不同。

针对概率值最大的是dome(圆顶),可以看到颜色最深的地方汇聚于建筑物的圆顶部分。而针对palace(宫殿),模型的关注点更加宽泛,深颜色区域遍布整个建筑物。由此可见,CNN确实可以捕捉到圆顶的特征和位置信息。

在这里插入图片描述

二、CAM技术实现

CAM技术其实就是全局平均池化层和一个全连接层来实现;全局平均池化提出于论文:Network In Network,与最大池化相比,全局平均池化层可以保留更多的图像信息,因为它计算每个池化区域的平均值,有助于捕捉整体图像的分布信息。

全局平均池化层一般应用于网络的末端。假设有一个三通道的特征图,将红色特征图所有值算平均得到一个神经元输出,蓝绿特征图类似。有多少个通道就能输出一个相同维度的向量,而这个向量经过softmax就能得到分类的概率值。所以如果要用于ImageNet进行分类,在经过全局平均池化层前的特征图是有1000个通道的,对应1000个类别。
在这里插入图片描述
由于GAP(Global Average Pooling 全局平均池化层)考虑一整个通道的这种特性,那么我在GAP的后面加上一个全连接层,相当于对每个通道(因为GAP得到的向量是由每个通道全局平均得来的)进行加权,其他和普通网络一样。

在这里插入图片描述
举个例子,如果神经网络将图片分类成”狗“,那么把”狗“对应的全连接层权重与GAP前的特征图进行加权计算,权重值代表了某个特征对图像分类为”狗“的重要性。例如Resnet18最后一层特征图有512个通道。这512个通道可以认为提取到不同的特征,该特征具有高度抽象性,且每个通道对最后的结果贡献不同,因此单独可视化每个通道获取热图也让人很难理解。所以CAM技术根据每个通道不同的贡献大小去融合获取一张CAM,类似于热力图可以很清晰的看出网络聚焦的特征。

2.1 网络修改

由于CAM技术要求网络的最后由GAP + FC组成,但是AlexNet、VGG这些网络都不符合,所以需要自行修改网络结构。如何修改网络结构参考:【Pytorch】加载预训练模型及修改网络结构

这里以修改VGG16为例:

def get_vgg(depth = 16):
	# 调用torchvision中的VGG16
    vgg = getattr(torchvision.models, "".join(['vgg', str(depth)]))()
    # 修改维全局平均池化层,即一个通道的特征图最后只输出一个值
    vgg.avgpool.output_size = (1, 1)
    # 修改全连接层
    vgg.classifier = nn.Linear(512, 100)
    
    return vgg

ResNet网络是符合GAP + FC的结构的,所以如果用ResNet进行试验的话可以直接调用预训练模型。

2.2 微调

由于修改了网络结构,所以需要对参数进行微调。微调可以只微调全连接层,其余层加载官方的预训练权重。训练时只需要将全连接层的参数传入SGD优化器即可model.classifier.parameters(),我使用预训练的权重在MINI-ImageNet数据集上做100分类的微调 ,MINI-ImageNet参考之前的记录:【代码实验】CNN实验——利用Imagenet子集训练分类网络(AlexNet/ResNet)

微调后的权重文件:VGG16_100.pth,提取码:3v0i

	...
 # 训练时只需要将全连接层的参数传入SGD优化器即可
  optimizer = optim.SGD(model.classifier.parameters(), lr=args.learning_rate,
                            momentum=0.9, weight_decay=0.0005) 
    ...

2.2 特征提取

由于需要获取GAP层前的特征图,所以需要获取模型的中间输出,参考之前的记录:【Pytorch】六行代码实现:特征图提取与特征图可视化


以上三点是前置知识,可以更好地理解代码


三、总结

这里没有对CAM技术的理论进行详细解释,更推荐看原论文。如果只想应用的话,完整代码见Github :https://github.com/capsule2077/CAM-Visualization,代码中有注释,欢迎提出问题,如果有用可以点个Star,谢谢!

Reference

万字长文:特征可视化技术(CAM)
神经网络可视化——CAM及其变体

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/81529.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

视频转云存的痛点

现在运营商体系里面,有大量的视频转云存储的需求,但是视频云存储有一个比较大的痛点,就是成本! 成本一:存储成本; 我们以1000路2M视频转云存,存储时间为90天为例(B端存储时间有时候…

Java | IDEA中 jconsole 不是内部或外部命令,也不是可运行的程序

解决办法: 1.先将Terminal的Shell path 修改为C:\WINDOWS\system32\cmd.exe 2.在检查环境变量中的ComSpec的值 3.找到自己电脑下载的jdk的bin的地址 4.将jdk的bin地址加入到系统变量path中

仪表板展示 | DataEase看中国:2023年中国电影市场分析

背景介绍 随着《消失的她》、《变形金刚:超能勇士崛起》、《蜘蛛侠:纵横宇宙》、《我爱你》等国内外影片的上映,2023年上半年的电影市场也接近尾声。据国家电影专资办初步统计,上半年全国城市院线票房达262亿元,已经超…

Mariadb高可用MHA (四十二)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 一、概述 1.1 概念 1.2 组成 1.3 特点 1.4 工作原理 二、构建MHA 2.1 ssh免密登录 2.2 主从复制 2.3 MHA安装 2.3.1所有节点安装perl环境 2.3..2 node 2.3.…

SpringBoot + Vue 微人事权限组管理模块 (十四)

权限组前端页面制作 权限组管理角色和菜单之间关系&#xff0c;操作员管理着用户和角色之间的关系。 英文的输入框要有个前缀&#xff0c;SpringSecurity里角色英文名需要加一个ROLE_的前缀 上代码 <div><div class"permissManaTool"><el-input pla…

完全备份、增量备份、差异备份、binlog日志

Top NSD DBA DAY06 案例1&#xff1a;完全备份与恢复案例2&#xff1a;增量备份与恢复案例3&#xff1a;差异备份与恢复案例4&#xff1a;binlog日志 1 案例1&#xff1a;完全备份与恢复 1.1 问题 练习物理备份与恢复练习mysqldump备份与恢复 1.2 方案 在数据库服务器192…

【图论】Floyd算法

一.简介 Floyd算法&#xff0c;也称为Floyd-Warshall算法&#xff0c;是一种用于解决所有节点对最短路径问题的动态规划算法。它可以在有向图或带权图中找到任意两个节点之间的最短路径。 Floyd算法的基本思想是通过中间节点逐步优化路径长度。它使用一个二维数组来存储任意两…

java面试基础 -- ArrayList 和 LinkedList有什么区别, ArrayList和Vector呢?

目录 基本介绍 有什么不同?? ArrayList的扩容机制 ArrayLIst的基本使用 ArrayList和Vector 基本介绍 还记得我们的java集合框架吗, 我们来复习一下, 如图: 可以看出来 ArrayList和LinkedList 都是具体类, 他们都是接口List的实现类. 但是他们底层的逻辑是不同的, 相信…

RabbitMq交换机类型介绍

RabbitMq交换机类型介绍 在RabbitMq中&#xff0c;生产者的消息都是通过交换器来接收&#xff0c;然后再从交换器分发到不同的队列&#xff0c;再由消费者从队列获取消息。这种模式也被成为“发布/订阅”。 分发的过程中交换器类型会影响分发的逻辑。 直连交换机&#xff1a…

【Go】Go 文本匹配 - 正则表达式基础与编程中的应用 (8000+字)

正则表达式&#xff08;Regular Expression, 缩写常用regex, regexp表示&#xff09;是计算机科学中的一个概念&#xff0c;很多高级语言都支持正则表达式。 目录 何为正则表达式 语法规则 普通字符 字符转义 限定符 定位符 分组构造 模式匹配 regexp包 MatchString…

什么是Eureka?以及Eureka注册服务的搭建

导包 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 htt…

LeetCode_Java_2236. 判断根结点是否等于子结点之和

2236. 判断根结点是否等于子结点之和 给你一个 二叉树 的根结点 root&#xff0c;该二叉树由恰好 3 个结点组成&#xff1a;根结点、左子结点和右子结点。 如果根结点值等于两个子结点值之和&#xff0c;返回 true &#xff0c;否则返回 false 。 示例1 输入&#xff1a;roo…

java八股文面试[JVM]——JVM内存结构

参考&#xff1a; JVM学习笔记&#xff08;一&#xff09;_卷心菜不卷Iris的博客-CSDN博客 JVM是运行在操作系统之上的&#xff0c;它与硬件没有直接的交互 JVM内存结构&#xff1a; 方法区&#xff1a;存储已被虚拟机加载的类元数据信息(元空间) 堆&#xff1a;存放对象实…

SpringBoot项目集成ElasticSearch服务

本文已收录于专栏 《中间件合集》 目录 版本介绍背景介绍优势说明集成过程1.引入依赖2.添加配置文件3.初始化 示例说明代码结果 总结提升 版本介绍 Spring boot的版本是&#xff1a; 2.3.12   ElasticSearch的版本是&#xff1a;7.6.2 背景介绍 在我们的项目中经常会遇到对于…

BOXTRADE-天启量化分析平台 系统功能预览

BOXTRADE-天启量化分析平台 系统功能预览 系统功能预览 1.登录 首页 参考登录文档 2. A股 行情与策略分析 2.1 A股股票列表 可以筛选和搜索 2.2 A股行情及策略回测 2.2.1 行情数据提供除权和前复权&#xff0c;后复权数据&#xff1b;外链公司信息 2.2.2 内置策略执行结果…

中国剩余定理及扩展

目录 中国剩余定理解释 中国剩余定理扩展——求解模数不互质情况下的线性方程组&#xff1a; 代码实现&#xff1a; 互质&#xff1a; 非互质&#xff1a; 中国剩余定理解释 在《孙子算经》中有这样一个问题&#xff1a;“今有物不知其数&#xff0c;三三数之剩二&#x…

计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下计算机视觉的应用11-基于pytorch框架的卷积神经网络与注意力机制对街道房屋号码的识别应用&#xff0c;本文我们借助PyTorch&#xff0c;快速构建和训练卷积神经网络&#xff08;CNN&#xff09;等模型&#xff0c;…

Quivr 基于GPT和开源LLMs构建本地知识库 (更新篇)

一、前言 自从大模型被炒的越来越火之后&#xff0c;似乎国内涌现出很多希望基于大模型构建本地知识库的需求&#xff0c;大概在5月底的时候&#xff0c;当时Quivr发布了第一个0.0.1版本&#xff0c;第一个版本仅仅只是使用LangChain技术结合OpenAI的GPT模型实现了一个最基本的…

【搭建WebDAV服务手机ES文件浏览器远程访问】

文章目录 1. 安装启用WebDAV2. 安装cpolar3. 配置公网访问地址4. 公网测试连接5. 固定连接公网地址6. 使用固定地址测试连接 有时候我们想通过移动设备访问群晖NAS 中的文件,以满足特殊需求,我们在群辉中开启WebDav服务,结合cpolar内网工具生成的公网地址,通过移动客户端ES文件…

在ubuntu中将dict.txt导入到数据库sqlite3

将dict.txt导入到数据库 #include <head.h> #include <sqlite3.h> int do_insert(int i,char *str,sqlite3 *db); int main(int argc, const char *argv[]) {//创建泵打开一个数据库sqlite3 *db NULL;if(sqlite3_open("./my.db",&db) ! SQLITE_OK){…