目标检测文献阅读-DETR:使用Transformer进行端到端目标检测

目录

摘要

Abstract

1 引言

2 DETR结构

2.1 Backbone

2.2 Encoder

2.3 Decoder

2.4 FFN

3 目标检测集合预测损失

3.1 二分图匹配损失

3.2 损失函数

总结


摘要

本周阅读的论文题目是《End-to-End Object Detection with Transformers》(使用Transformer进行端到端目标检测)。目前大多数目标检测方法都是两阶段的(例如R-CNN系列),需要先生成候选区域再进行分类和回归,就算是单阶段(例如YOLO系列),最后往往也要进行NMS后处理步骤来去除预测框,而本文中提出了一种将目标检测视为直接的集合预测问题的新方法DEtection TRansformer,即DETR,是基于集合的全局损失,该损失通过二分匹配和Transformer编码器-解码器体系结构来强制进行唯一预测,从而能够进行并行处理,只会输出一个预测框。并且DETR结构灵活简单,易于实施,可以轻松扩展到其他领域,例如全景分割。

Abstract

This week's paper is titled "End-to-End Object Detection with Transformers." At present, most object detection methods are two-stage (e.g., R-CNN series), which need to be formed into candidate regions and then classified and regressed, even if they are single-stage (e.g., YOLO series), and finally NMS post-processing steps are often carried out to remove the prediction box, and this paper proposes a new method that treats object detection as a direct ensemble prediction problem, DEtection TRansformer, i.e., DETR, which is based on the global loss of sets. This loss forces a unique prediction through binary matching and a Transformer encoder-decoder architecture, enabling parallel processing with only one prediction box output. And the DETR structure is flexible and simple, easy to implement, and can be easily extended to other fields, such as panoramic segmentation.

文献链接🔗:End-to-End Object Detection with Transformers

1 引言

目标检测的目标是预测每个感兴趣对象的一组边界框和类别标签。而在本文之前,现代检测器通过在大量proposal、anchor 、窗口中心上定义替代回归和分类问题,以间接方式解决预测任务,例如R-CNN系列、YOLO系列等,导致重复的框太多,需要频繁进行复杂的后处理步骤,影响了目标检测的速度。

而本文中将目标检测任务直接看成是集合预测的问题,从而简化了训练过程。DETR采用基于Transformer的编码器-解码器架构,Transformer的自注意机制明确地对序列中元素之间的所有成对相互作用进行了建模,使这些体系结构特别适合于集合预测的特定约束,例如删除重复的预测。

如下图,DETR通过将公共CNN与Transformer架构相结合,一次预测所有对象,并使用一组损失函数进行端到端训练,直接并行预测并行最终检测集。在训练期间,二分匹配唯一地分配具有地面实况框的预测,而没有匹配的预测应该产生一个“no object”类预测。

54b5530c49b647eaac8e828321cefe1a.png

 DETR主要特征有:

  • Transformer (非自回归)并行解码:在解码器的部分进行并行的出框;
  • 全局的二分匹配损失:在预测对象和真实对象之间执行二分匹配。

2 DETR结构

DETR的模型架构非常简单,使得它能够在几乎所有的深度学习框架下都可以实现,只要有CNN和Transformer就可以了。总体框架如下图,可以拆分为backbone、encoder、decoder和prediction heads四个部分:

e68268f8cd8049c0b25b2d62aa826e98.png

  • Backbone:CNN backbone学习图像的2D特征;
  • Positional Encoding:将2D特征展平,并对其使用位置编码(positional encoding);
  • Encoder:经过Transformer的encoder;
  • Decoder:encoder的输出+object queries作为Transformer的decoder输入;
  • Prediction Heads:将decoder的每个输出都送到FFN去输出检测结果。

2.1 Backbone

backebone就是一个传统的CNN模型,本文中DETR使用Imagenet预训练好的Resnet,作用是抽取图片的特征信息。

通常假设Backbone的输入初始图像eq?x_%7Bimg%7D%5Cin%20%5Cmathbb%7BR%7D%5E%7B3%5Ctimes%20H_0%20%5Ctimes%20W_0%20%7D (即具有3个颜色通道,高和宽分别为eq?H_0 和eq?W_0 ),则输出通道eq?f%5Cin%20%5Cmathbb%7BR%7D%5E%7BC%5Ctimes%20H%20%5Ctimes%20W%20%7D ,通常eq?C%3D2048 、eq?H%3D%5Cfrac%7BH_0%7D%7B32%7D 、eq?W%3D%5Cfrac%7BW_0%7D%7B32%7D (图像高和宽都变为了1/32)。 

self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])

x = self.backbone(inputs)
H, W = x.shape[-2:]

2.2 Encoder

f174a0a7bd5247cea65ac0efa33e29e8.png

经过Backbone后,由于输出特征图eq?f 的eq?C%3D2048 是每个 token 的维度,还是比较大,并且encoder的输入为序列,所以先经过一个eq?1%20%5Ctimes%201 的卷积进行降维到更小的维度eq?d ,得到一个新的特征映射eq?z_0%5Cin%20%5Cmathbb%7BR%7D%5E%7Bd%5Ctimes%20H%20%5Ctimes%20W%20%7D ,然后再输入encoder会更好。

self.conv = nn.Conv2d(2048, hidden_dim, 1)  # 1×1卷积层将2048维特征降到256维

x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]

此时encoder的self-attention在特征图上进行全局分析,因为最后一个特征图对于大物体比较友好,那么在上面进行self-attention会便于网络更好的提取不同位置不同大目标之间的相互关系的联系。所以DETR在大目标上效果比Faster R-CNN好就比较容易理解到了。

self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

由于Transformer架构是置换不变的,所以本文中用添加到每个注意层输入的固定位置编码对其进行补充。为了体现图像在eq?x 和eq?y 维度上的信息,本文中分别计算了两个维度的位置编码,然后cat 到一起:

self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1))

2.3 Decoder

decoder同样遵循Transformer的标准架构,使用多头注意力机制和encoder-decoder注意机制转换大小为eq?d 的eq?N 个嵌入。与原始转换器的不同之处在于,本文中DETR在每个encoder层并行解码eq?N 个对象。由于encoder也是排列不变的,因此eq?N 个输入嵌入必须不同才能产生不同的结果。这些输入嵌入是学习到的位置编码,即查询对象,与encoder类似,本文中将它们添加到每个注意力层的输入中。eq?N 个查询对象由encoder转换为输出嵌入,然后通过FFN将它们独立解码为框坐标和类标签,从而产生 eq?N 个最终预测。在这些嵌入上使用self-attention和encoder-decoder注意力,该模型使用它们之间的成对关系对所有对象进行全局推理,同时能够将整个图像用作上下文。

self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))  # 查询对象

2.4 FFN

最后,FFN是由具有ReLU激活函数和隐藏维数eq?d 的3层感知器和一个线性投影层来计算的,或者说就是 eq?1%20%5Ctimes%201 的卷积。FFN预测框归一化的中心坐标、输入图形框的高度和宽度,然后使用softmax函数激活获得预测类标签。

self.linear_class = nn.Linear(hidden_dim, num_classes + 1)  # 类别FFN
self.linear_bbox = nn.Linear(hidden_dim, 4)  # 回归FFN

本文中DETR主体代码如下:

import torch
from torch import nn
from torchvision.models import resnet50

class DETR(nn.Module):
    def __init__(self, num_classes, hidden_dim, nheads,
        num_encoder_layers, num_decoder_layers):
        super().__init__()
        # We take only convolutional layers from ResNet-50 model
        self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
        self.conv = nn.Conv2d(2048, hidden_dim, 1) # 1×1卷积层将2048维特征降到256维
        self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1) # 类别FFN
        self.linear_bbox = nn.Linear(hidden_dim, 4)                # 回归FFN
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) # object query
        # 下面两个是位置编码
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        x = self.backbone(inputs)
        h = self.conv(x)
        H, W = h.shape[-2:]
        pos = torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
       					 self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
       					 ], dim=-1).flatten(0, 1).unsqueeze(1) # 位置编码
       					 
        h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),self.query_pos.unsqueeze(1))
        return self.linear_class(h), self.linear_bbox(h).sigmoid()


detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)

3 目标检测集合预测损失

DETR在一次通过解码器的过程中推断出eq?N 个固定大小的预测集(本文中eq?N 为100),可以知道eq?N 为明显大于图像中感兴趣对象的实际数量的,那预测出来的框怎么知道对应的是哪一个GT (Ground Truth)框呢,即如何来计算损失?

3.1 二分图匹配损失

首先,要将GT也扩展为eq?N 个检测框,用eq?y (eq?y 大小为eq?N ,且用一个额外的特殊类标签ϕ来表示“no object”)来代表GT的集合,而eq?%5Chat%7By%7D%3D%5Cleft%20%5C%7B%20%5Chat%7By%7D%20%5Cright%20%5C%7D%5EN_%7Bi%3D1%7D 代表eq?N 个预测的集合。这样两个集合的元素数量都为eq?N ,就可以做一个配对的操作,让预测集合的元素都能找到GT集合里的一个配对元素,每个预测集合的元素找到的GT集合里的元素都是不同的,也就是一一对应。

7c3d37f58d424684bf41c210088d49e9.png

这样的组合可以有eq?N%21 种,所有的组合记作eq?%5Cvarsigma_N ,然后搜素具有最小损失的eq?N 个元素eq?%5Csigma%20%5Cin%20%5Cvarsigma_N 的置换:

28c65b4f6d9f4c0da270b16728b354d5.png

其中,eq?%5Cpounds%20_%7Bmatch%7D%28y_i%2C%5Chat%7By%7D_%7B%5Csigma%20%28i%29%7D%29 是GT eq?y_%7Bi%7D  与索引为eq?%5Csigma%20%28i%29 的预测之间的成对匹配损失:

047b2398c33040029b3dc5dbcf870a78.png

 其中:

  • eq?%5Cmathbb%7BI%7D 是示1符号,后面括号的内容为真时取值1,否则取值0;
  • eq?i 表示GT中的第eq?i 个元素;
  • eq?c_i 表示GT中的第eq?i 个类;
  • eq?b_i 表示GT中的第eq?i 个边界框;
  • eq?%5Csigma%20%28i%29 是某个组合中GT第eq?i 个元素对应的预测中的索引;
  • eq?%5Chat%7Bp%7D_%7B%5Csigma%20%28i%29%7D 表示预测中的第eq?%5Csigma%20%28i%29 个结果;
  • eq?%5Chat%7Bp%7D_%7B%5Csigma%20%28i%29%7D%28c_i%29 表示预测中第eq?%5Csigma%20%28i%29 个结果中eq?c_i 的概率;
  • eq?%5Chat%7Bb%7D_%7B%5Csigma%20%28i%29%7D 表示预测中的第eq?%5Csigma%20%28i%29 个边界框。

对于 eq?%5Cpounds%20_%7Bbox%7D ,本文中由于是直接预测边界框,如果像其他方法中直接计算eq?L_1 损失,就会导致对于大的框和小的框的惩罚力度不一致,所以文章在使用eq?L_1 损失的同时,也使用了尺度不变的IoU 损失eq?%5Cpounds%20_%7Biou%7D :

b4c7dd3ec58845ad97de1762e216520d.png

则有eq?%5Cpounds%20_%7Bbox%7D :

d03c46cdee1a4663be3a2d00ccbefe11.png

其中,eq?%5Clambda%20_%7Biou%7D 和eq?%5Clambda%20_%7BL_i%7D 是超参数。

由此,本文中通过采用匈牙利算法来进行二分图匹配,即对预测集合和真实集合的元素进行一一对应,使得匹配损失最小。

3.2 损失函数

然后,就是计算损失函数了,也就是计算上一步得到的所有匹配对之间的匈牙利损失。 本文中将损失函数定义为与常见目标检测器相似的形式,即一个用于类别预测的负对数似然和一个边框损失的线性组合:

8ee6e6bf9dfa404fa40960a1834a2b13.png

其中,eq?%5Chat%7B%5Csigma%20%7D 是匹配损失中计算得到的最优分配。

eq?%5Cpounds%20_%7Bmatch%7D 区别是使用了eq?log ,目的是在计算训练模型的损失函数时需要得到准确的结果,ϕ就是ϕ,而不要似有似无、相近类似,这样会干扰预测的准确性。而在eq?%5Cpounds%20_%7Bmatch%7D 不使用eq?log 则不管ϕ 的预测结果,并且如果使用eq?log 也会增大计算量。

可以得到DETR的训练流程如下:

  1. CNN提取特征:经过CNN提取一部分的特征得到对应的特征图,并将得到的特征进行拉直处理形成token;
  2. 将拉直之后的token添加位置编码送入encoder的结构部分,encoder作用是进一步学习全局信息,为接下来的decoder出预测框做铺垫;
  3. decoder生成框的输出,当有了图像特征之后,还会有一个查询对象(限定了要出多少预测框),通过查询和特征在decoder里进行自注意力操作,得到输出的框(本文中预测框限定为100,无论是什么图片都会预测100个框);
  4. 计算100个预测框和2个GT框的二分图匹配损失,决定100个预测框哪两个是独一无二对应到红黄色的GT框,匹配的框去算目标检测的loss。

而在DETR的推理中,与训练流程的1、2、3一致,第4步中不需要计算损失,直接在最后使用阈值保留输出中置信度比较大(>0.7的),而置信度小于0.7的当做背景物体。

总结

DTER基于Transformer和二分匹配损失,通过直接集合预测来进行目标检测,提出了一种端到端目标检测新方法。DETR不需要预定义的先验anchor,也不需要NMS的后处理策略,就可以实现端到端的目标检测。并且DETR易于实现,具有灵活的架构,易于扩展到其他的复杂任务,例如全景分割。

DETR在大目标检测上的性能优势明显,而在小目标上稍差,而且基于match的损失导致学习很难收敛。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/950705.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

服务器双网卡NCCL通过交换机通信

1、NCCL变量设置 export CUDA_DEVICE_MAX_CONNECTIONS1 export NCCL_SOCKET_IFNAMEeno2 export NCCL_IB_DISABLE0 #export NCCL_NETIB export NCCL_IB_HCAmlx5_0,mlx5_1 export NCCL_IB_GID_INDEX3 export NCCL_DEBUGINFOGPUS_PER_NODE4MASTER_ADDR192.168.1.2 MASTER_PORT600…

B树及其Java实现详解

文章目录 B树及其Java实现详解一、引言二、B树的结构与性质1、节点结构2、性质 三、B树的操作1、插入操作1.1、插入过程 2、删除操作2.1、删除过程 3、搜索操作 四、B树的Java实现1、节点类实现2、B树类实现 五、使用示例六、总结 B树及其Java实现详解 一、引言 B树是一种多路…

数据分析思维(八):分析方法——RFM分析方法

数据分析并非只是简单的数据分析工具三板斧——Excel、SQL、Python,更重要的是数据分析思维。没有数据分析思维和业务知识,就算拿到一堆数据,也不知道如何下手。 推荐书本《数据分析思维——分析方法和业务知识》,本文内容就是提取…

微信小程序用的SSL证书有什么要求吗?

微信小程序主要建立在手机端使用,然而手机又涉及到各种系统及版本,所以对SSL证书也有要求,如果要小程序可以安全有效的访问需要满足以下要求: 1、原厂SSL证书(原厂封)。 2、DV单域名或者DV通配符。 3、兼…

手动安装 Maven 依赖到本地仓库

文章目录 手动安装 Maven 依赖到本地仓库1. 下载所需的 JAR 文件2. 安装 JAR 文件到本地仓库3. 验证安装4. 在项目中使用该依赖 手动安装 Maven 依赖到本地仓库 遇到的问题: idea导入一个新的工程,发现pom文件中的一些依赖死活下载不下来,这…

VSCode Live Server 插件安装和使用

VSCode Live Server是一个由Ritwick Dey开发的Visual Studio Code扩展插件,它提供了一个带有实时重载功能的本地开发服务器。在VSCode中安装和使用Live Server插件进行实时预览和调试Web应用程序。这将大大提高前端开发效率,使网页设计和开发变得更为流畅…

UART串口数据分析

串口基础知识详细介绍: 该链接详细介绍了串并行、单双工、同异步、连接方式 https://blog.csdn.net/weixin_43386810/article/details/127156063 该文章将介绍串口数据的电平变化、波特率计算、脉宽计算以及数据传输量的计算。 捕获工具:逻辑分析仪&…

Internet协议原理

文章目录 考试说明Chapter 0: 本书介绍Chapter 1: Introduction And Overview 【第1章:引言与概述】Chapter 2: Overview Of Underlying Network Technologies 【第2章:底层网络技术的回顾】Chapter 3: Internetworking Concept And Architectural Model…

DeepSeek-V3 通俗详解:从诞生到优势,以及与 GPT-4o 的对比

1. DeepSeek 的前世今生 1.1 什么是 DeepSeek? DeepSeek 是一家专注于人工智能技术研发的公司,致力于打造高性能、低成本的 AI 模型。它的目标是让 AI 技术更加普惠,让更多人能够用上强大的 AI 工具。 1.2 DeepSeek-V3 的诞生 DeepSeek-V…

linux之自动挂载

如果想要实现自动挂载,应该挂在客户端!!!!! 客户端: [rootlocalhost ~]# yum install nfs-utils -y (下载软件) [rootlocalhost ~]# systemctl start nfs-utils.servic…

RHCSA知识点汇总

第0章:Linux基础入门 0.1 什么是计算机 计算机的组成: 控制器:是整个计算机的中枢神经,根据程序要求进行控制,协调计算机各部分工作及内存与外设的访问等。 输入设备:将文字、数据、程序和控制命令等信…

交响曲-24-3-单细胞CNV分析及聚类

CNV概述 小于1kb是常见的插入、移位、缺失等的变异 人体内包含<10% 的正常CNV&#xff0c;我们的染色体数是两倍体&#xff0c;正常情况下&#xff0c;只有一条染色体表达&#xff0c;另一条沉默&#xff0c;当表达的那条染色体发生CNV之后&#xff0c;表达数量就会成倍增加…

【Linux-多线程】POSIX信号量-基于环形队列生产消费模型

POSIX信号量 POSIX信号量和System V信号量作用相同&#xff0c;都是用于同步操作&#xff0c;达到无冲突的访问共享资源的目的。但POSIX可以用于线程间同步 1.快速认识信号量接口 POSIX信号量分为两种类型&#xff1a; 命名信号量&#xff08;Named Semaphores&#xff09;&…

Linux下文件操作相关接口

文章目录 一 文件是什么普通数据文件 二 文件是谁打开的进程用户 三 进程打开文件的相关的接口c语言标准库相关文件接口1. fopen 函数2. fread 函数3. fwrite 函数4. fclose 函数5. fseek 函数 linux系统调用接口1. open 系统调用2. creat 系统调用3. read 系统调用4. write 系…

UE蓝图节点备忘录

获取索引为0的玩家 获取视图缩放 反投影屏幕到世界 获取屏幕上的鼠标位置 对指定的物体类型进行射线检测 判断物体是否有实现某个接口 上面节点的完整应用 通过PlayerControlle获取相机相关数据 从相机处发射射线撞击物体从而获取物体信息 抽屉推拉功能 节点说明 ##门的旋转开关…

玩机搞机基本常识-------列举安卓机型一些不常用的adb联机命令

前面分享过很多 常用的adb命令&#xff0c;今天分享一些不经常使用的adb指令。以作备用 1---查看当前手机所有app包名 adb shell pm list package 2--查看当前机型所有apk包安装位置 adb shell pm list package -f 3--- 清除指定应用程序数据【例如清除浏览器应用的数据】 …

LeetCode【剑指offer】系列(字符串篇)

剑指offer05.替换空格 题目链接 题目&#xff1a;假定一段路径记作字符串path&#xff0c;其中以 “.” 作为分隔符。现需将路径加密&#xff0c;加密方法为将path中的分隔符替换为空格" "&#xff0c;请返回加密后的字符串。 思路&#xff1a;遍历即可。 通过代…

idea java.lang.OutOfMemoryError: GC overhead limit exceeded

Idea build项目直接报错 java: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceeded 设置 编译器 原先heap size 设置的是 700M , 改成 2048M即可

aws(学习笔记第二十二课) 复杂的lambda应用程序(python zip打包)

aws(学习笔记第二十二课) 开发复杂的lambda应用程序(python的zip包) 学习内容&#xff1a; 练习使用CloudShell开发复杂lambda应用程序(python) 1. 练习使用CloudShell CloudShell使用背景 复杂的python的lambda程序会有许多依赖的包&#xff0c;如果不提前准备好这些python的…

conda 批量安装requirements.txt文件

conda 批量安装requirements.txt文件中包含的组件依赖 conda install --yes --file requirements.txt #这种执行方式&#xff0c;一遇到安装不上就整体停止不会继续下面的包安装。 下面这条命令能解决上面出现的不执行后续包的问题&#xff0c;需要在CMD窗口执行&#xff1a; 点…