【论文阅读】DETR 论文逐段精读

【论文阅读】DETR 论文逐段精读

文章目录

  • 【论文阅读】DETR 论文逐段精读
  • 📖DETR 论文精读【论文精读】
    • 🌐前言
    • 📋摘要
    • 📚引言
    • 🧬相关工作
    • 🔍方法
      • 💡目标函数
      • 📜模型结构
      • ⚙️代码
    • 📌实验

参考跟李沐学AI: 精读DETR

📖DETR 论文精读【论文精读】


🌐前言

目标检测领域:从目标检测开始火到 detr 都很少有端到端的方法,大部分方法最后至少需要后处理操作(NMS, non-maximum suppression 非极大值抑制)。有了 NMS,模型调参就会很复杂,而且即使训练好了一个模型,部署起来也非常困难(NMS 不是所有硬件都支持)。

📋摘要

贡献:把目标检测做成一个端到端的框架,把之前特别依赖人的先验知识的部分删掉了(NMS 部分、anchor)。

DETR提出

  • 新的目标函数,通过二分图匹配的方式,强制模型输出一组独一无二的预测(没有那么多冗余框,每个物体理想状态下就会生成一个框)
  • 使用 encoder-decoder 的架构

两个小贡献:

  1. decoder 还有另外一个输入 learned object query,类似 anchor 的意思
    (给定这些object query之后,detr就可以把learned object query和全局图像信息结合一起,通过不同的做注意力操作,从而让模型直接输出最后的一组预测框)
  2. 想法&&实效性:并行比串行更合适

DETR 的好处:

  1. 简单性:想法上简单,不需要一个特殊的 library,只要硬件支持 transformer 或 CNN,就一定支持 detr
  2. 性能:在 coco 数据集上,detr 和一个训练非常好的 faster RCNN 基线网络取得了差不多的效果,模型内存和速度也和 faster RCNN 差不多
  3. 想法好,解决了目标检测领域很多痛点,写作好
  4. 别的任务:全景分割任务上 detr 效果很好,detr 能够非常简单拓展到其他任务上

📚引言


DETR 流程(训练)

  1. CNN 提特征
  2. 特征拉直,送到 encoder-decoder 中,encoder 作用:进一步学习全局信息,为近下来的 decoder,也就是最后出预测框做铺垫。
  3. decoder 生成框的输出,当你有了图像特征之后,还会有一个 object query(限定了你要出多少框),通过 query 和特征在 decoder 里进行自注意力操作,得到输出的框(文中是100,无论是什么图片都会预测100个框)
  4. loss :二分图匹配,计算100个预测框和2个 GT 框的 matching loss,决定100个预测框哪两个是独一无二对应到红黄色的 GT 框,匹配的框去算目标检测的 loss

推理
1、2、3一致,第四步 loss 不需要,直接在最后的输出上用一个阈值卡一个输出的置信度,置信度比较大(>0.7的)保留,置信度小于0.7的当做背景物体。

🧬相关工作

让 DETR 成功主要原因:transformer

🔍方法

分两块:1、基于集合的目标函数怎么做,作者如何通过二分图匹配把预测的框和 GT 框连接在一起,算得目标函数 2、detr 具体模型架构

💡目标函数

DETR模型最后输出是一个固定集合,无论图片是什么,最后都会输出 n 个(本文 n=100)

问题:detr 每次都会出 100 个输出,但是实际上一个图片的 GT 的 bounding box 可能只有几个,如何匹配?如何计算 loss?怎么知道哪个预测框对应 GT 框?
匈牙利算法是解决该问题的一个知名且高效的算法,能够以较低的复杂度得到唯一的最优解。
在 scipy 库中,已经封装好了匈牙利算法,只需要将成本矩阵(cost matrix)输入进去就能够得到最优的排列。在 DETR 的官方代码中,也是调用的这个函数进行匹配(from scipy.optimize import linear_sum_assignment)。
从N个预测框中,选出与M个GT Box最匹配的预测框,也可以转化为二分图匹配问题,这里需要填入矩阵的“成本”,就是每个预测框和GT Box的损失。对于目标检测问题,损失就是分类损失和边框损失组成。

所以整个步骤就是:

  • 遍历所有的预测框和 GT Box,计算其 loss。
  • 将 loss 构建为 cost matrix,然后用 scipy 的 linear_sum_assignment(匈牙利算法)求出最优解,即找到每个 GT Box 最匹配的那个预测框。
  • 计算最优的预测框和 GT Box 的损失。(分类+回归)

但是在 DETR 中,损失函数有两点小改动:

  • 去掉分类损失中的 log
  • 回归损失为 L1 loss+GIOU

📜模型结构

下面参考官网的一个 demo,以输入尺寸3×800×1066为例进行前向过程:

  • CNN 提取特征([800,1066,3]→[25,34,256]
    backbone 为 ResNet-50,最后一个 stage 输出特征图为 25×34×2048(32 倍下采样),然后用 1×1 的卷积将通道数降为 256;
  • Transformer encoder 计算自注意力([25,34,256]→[850,256]
    将上一步的特征拉直为 850×256,并加上同样维度的位置编码(Transformer 本身没有位置信息),然后输入的 Transformer encoder 进行自注意力计算,最终输出维度还是 850×256;
  • Transformer decoder 解码,生成预测框
    decoder 输入除了 encoder 部分最终输出的图像特征,还有前面提到的 learned object query,其维度为 100×256。在解码时,learned object query 和全局图像特征不停地做 across attention,最终输出 100×256 的自注意力结果。
    这里的 object query 即相当于之前的 anchor/proposal,是一个硬性条件,告诉模型最后只得到 100 个输出。然后用这 100 个输出接 FFN 得到分类损失和回归损失。
  • 使用检测头输出预测框
    检测头就是目标检测中常用的全连接层(FFN),输出 100 个预测框( h x c e n t e r , y c e n t e r , w , h h x_{center}, y_{center}, w, h hxcenter,ycenter,w,h )和对应的类别。
  • 使用二分图匹配方式输出最终的预测框,然后计算预测框和真实框的损失,梯度回传,更新网络。

除此之外还有部分细节:

  • Transformer-encode/decoder 都有 6层
  • 除第一层外,每层 Transformer encoder 里都会先计算 object query 的 self-attention,主要是为了移除冗余框。这些 query 交互之后,大概就知道每个 query 会出哪种框,互相之间不会再重复(见实验)。
  • decoder 加了 auxiliary loss,即每层 decoder 输出的 100×256 维的结果,都加了 FFN 得到输出,然后去计算 loss,这样模型收敛更快。(每层 FFN 共享参数)

⚙️代码

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self, num_classes, hidden_dim, nheads,
        num_encoder_layers, num_decoder_layers):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1) # 1×1卷积层将2048维特征降到256维
        self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # 类别FFN
        self.linear_bbox = nn.Linear(hidden_dim, 4)                # 回归FFN
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # object query
        # 下面两个是位置编码
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
       					 self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
       					 ], dim=-1).flatten(0, 1).unsqueeze(1) # 位置编码
       					 
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()


detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

📌实验

  • 最上面一部分是 Detectron 2 实现的 Faster RCNN ,但是本文中作者使用了很多 trick
  • 中间部分是作者使用了 GIoU loss、更强的数据增强策略、更长的训练时间来把上面三个模型重新训练了一次,这样更显公平。重新训练的模型以+表示,参数量等这些是一样的,但是普偏提了两个点
  • 下面部分是 DETR 模型,可以看到参数量、GFLOPS 更小,但是推理更慢。模型比 Faster RCNN 精度高一点,主要是大物体检测提升 6 个点 AP,小物体相比降低了 4个点左右

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/508708.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Navicat工具使用

Navicat的本质: 在创立连接时提前拥有了数据库用户名和密码 双击数据库时,相当于建立了一个链接关系 点击运行时,远程执行命令,就像在xshell上操作Linux服务器一样,将图像化操作转换成SQL语句去后台执行 一、打开Navi…

plasmo内容UI组件层级过高导致页面展示错乱

我使用plasmo写了一个行内样式的UI组件,但是放到页面上之后,会和下拉组件出现层级错乱,看了一下样式,吓我一跳:层级竟然设置的如此之高 所以就需要将层级设置低一点: #plasmo-shadow-container {z-index: …

Java 操作 Hadoop 集群之 HDFS 的应用案例详解

Java 操作 Hadoop 注意:本文内容基于 Hadoop 集群搭建完成基础上: Linux 系统 CentOS7 上搭建 Hadoop HDFS集群详细步骤 本文的内容是基于下面前提: Hadoop 集群搭建完成并能正常启动和访问Idea 和 Maven 分别安装完成需要有 JavaSE 基础和熟悉操作hadoop 的 hdfs dfs 命令…

实验04_OSPF&RIP选路实验

实验拓扑 IP地址规划 拓扑中的 IP 地址段采用:172.16.AB.X/24。其中 AB 为两台路由器编号组合,例如:R3-R6 之间的 AB 为 36,X 为路由器编号,例如R3 的 X3所有路由器都有一个 loopback 0 接口,地址格式为&…

在 Three.js 中,`USDZExporter` 类用于将场景导出为 USDZ 格式,这是一种用于在 iOS 平台上显示增强现实(AR)内容的格式。

demo 案例 在 Three.js 中,USDZExporter 类用于将场景导出为 USDZ 格式,这是一种用于在 iOS 平台上显示增强现实(AR)内容的格式。下面是关于 USDZExporter 的入参、出参、方法和属性的讲解: 入参 (Parameters): sc…

【Frida】【Android】 07_爬虫之网络通信库HttpURLConnection

🛫 系列文章导航 【Frida】【Android】01_手把手教你环境搭建 https://blog.csdn.net/kinghzking/article/details/136986950【Frida】【Android】02_JAVA层HOOK https://blog.csdn.net/kinghzking/article/details/137008446【Frida】【Android】03_RPC https://bl…

WSL Ubuntu20 使用1panelSSH连接失败(SSH服务初始化配置)

文章目录 安装网络工具ssh配置ssh服务安装 配置信息(命令行)配置信息(可视化)基础配置(可省过)高级配置(必须) 面板中终端配置SSH连接 安装网络工具 安装net工具apt install net-to…

行人重识别项目 | 基于Pytorch实现ReID行人重识别算法

项目应用场景 面向行人重识别场景,项目具有轻量化 (训练的时候也只需要 2GB 的显存占用)、性能好 (只使用 softmax 损失就能够达到 Rank188.24%, mAP70.68%),另外提供友好的上手项目流程教程 项目效果: 项目流程 > 具体参见项目内README.…

书生·浦语大模型全链路开源体系-第2课

书生浦语大模型全链路开源体系-第2课 书生浦语大模型全链路开源体系-第2课相关资源实战部署InternLM2-Chat-1.8B模型准备环境下载模型运行案例 实战部署InternLM2-Chat-7B模型准备环境下载模型及案例代码运行cli案例代码运行web案例代码配置SSH公钥信息配置SHH隧道连接 熟悉 Hu…

Echarts实现高亮某一个点

背景 接口会返回所有点的数据,以及最优点的数据。产品要求在绘制图形后,高亮最优点,添加一个红色的样式,如图。点击select选择器时,可选择不同指标和花费对应的关系。 以下介绍实现思路 1、自定义配置选择器的数据源…

稀碎从零算法笔记Day36-LeetCode:H指数

有点绕的一个题,题目描述的有点奇怪(可以看下英文?) 题型:数组、模拟 链接:274. H 指数 - 力扣(LeetCode) 来源:LeetCode 题目描述 给你一个整数数组 citations &am…

ArcGIS Pro怎么进行挖填方计算

在工程实施之前,我们需要充分利用地形,结合实际因素,通过挖填方计算项目的标高,以达到合理控制成本的目的,这里为大家介绍一下ArcGIS Pro中挖填方计算的方法,希望能对你有所帮助。 数据来源 教程所使用的…

Android仿高德首页三段式滑动

最近发现很多app都使用了三段式滑动,比如说高德的首页和某宝等物流信息都是使用的三段式滑动方式,谷歌其实给了我们很好的2段式滑动,就是BottomSheet,所以这次我也是在这个原理基础上做了一个小小的修改来实现我们今天想要的效果。…

13.5k star, 免费开源 Markdown 编辑器

13.5k star, 免费开源 Markdown 编辑器 分类 开源分享 项目名: Editor.md -- Markdown 编辑器 Github 开源地址: https://github.com/pandao/editor.md 在线测试地址: Editor.md - 开源在线 Markdown 编辑器 完整实例: HTML Preview(mark…

LLaMA-Factory微调(sft)ChatGLM3-6B保姆教程

LLaMA-Factory微调(sft)ChatGLM3-6B保姆教程 准备 1、下载 下载LLaMA-Factory下载ChatGLM3-6B下载ChatGLM3windows下载CUDA ToolKit 12.1 (本人是在windows进行训练的,显卡GTX 1660 Ti) CUDA安装完毕后&#xff0c…

iPhone设备中查看应用程序崩溃日志的最佳实践与经验分享

​ 目录 如何在iPhone设备中查看崩溃日志 摘要 引言 导致iPhone设备崩溃的主要原因是什么? 使用克魔助手查看iPhone设备中的崩溃日志 奔溃日志分析 总结 摘要 本文介绍了如何在iPhone设备中查看崩溃日志,以便调查崩溃的原因。我们将展示三种不同的…

RT-Thread 学习二:基于 RT Thread studio (联合cubemx)的stm32l475VETX pwm 呼吸灯实验

1、基于芯片创建呼吸灯项目 基于rtthread studio 可以帮助我们创建好底层驱动,我们只需要做简单配置 2、配置pwm 设备 基于芯片生成的项目默认只打开了 uart 和 pin的驱动,对于其它硬件驱动需要自行配置; 对于pwm和设备的驱动分为四步&am…

深入理解数据结构——堆

前言: 在前面我们已经学习了数据结构的基础操作:顺序表和链表及其相关内容,今天我们来学一点有些难度的知识——数据结构中的二叉树,今天我们先来学习二叉树中堆的知识,这部分内容还是非常有意思的,下面我们…

将 Elasticsearch 向量数据库引入到数据上的 Azure OpenAI 服务(预览)

作者:来自 Elastic Aditya Tripathi Microsoft 和 Elastic 很高兴地宣布,全球下载次数最多的向量数据库 Elasticsearch 是公共预览版中 Azure OpenAI Service On Your Data 官方支持的向量存储和检索增强搜索技术。 这项突破性的功能使你能够利用 GPT-4 …

LeetCode-240. 搜索二维矩阵 II【数组 二分查找 分治 矩阵】

LeetCode-240. 搜索二维矩阵 II【数组 二分查找 分治 矩阵】 题目描述:解题思路一:从左下角或者右上角元素出发,来寻找target。解题思路二:右上角元素,代码解题思路三:暴力也能过解题思路四:二分…