DEtection TRansformer (DETR)与YOLO在目标检测方面的比较

1. 概述

计算机视觉中的目标检测是一个复杂而有趣的领域,它涉及到让计算机能够识别图像中的物体,并确定它们的位置。下面是DETR和YOLO这两种目标检测方法简单比较:

1.1 YOLO

YOLO是一种非常流行的目标检测算法,它的核心思想是将目标检测任务视为一个回归问题。YOLO将输入图像划分为一个个格子(grid),每个格子负责预测中心点落在该格子内的目标。YOLO会预测边界框(bounding boxes)的位置以及边界框内物体的类别。

YOLO的特点包括:

  • 速度快:YOLO的设计使其能够非常快速地进行目标检测,这使得它非常适合实时应用。
  • 整体性:YOLO将整个图像作为一个整体来处理,这有助于它捕捉到图像中的上下文信息。
  • 版本迭代:YOLO已经发展到了多个版本(如YOLOv1, YOLOv2, YOLOv3, YOLOv4, YOLOv5等),每个版本都在性能和速度上有所提升。

1.2 DETR

DETR是一种较新的基于transformer的目标检测方法。与YOLO不同,DETR不依赖于卷积神经网络(CNN)来提取特征,而是使用transformer架构来直接预测目标的类别和位置。

DETR的特点包括:

  • 基于transformer:DETR利用了transformer模型的自注意力机制,这使得它能够捕捉到全局上下文信息。
  • 端到端:DETR是一个端到端的模型,它直接从图像到边界框和类别标签进行预测,无需额外的锚框或复杂的后处理步骤。
  • 灵活性:由于transformer架构的灵活性,DETR可以容易地扩展到其他任务,如实例分割等。

2、算法比较

自2012年以来,计算机视觉经历了一场由卷积神经网络(CNN)和深度学习架构带来的革命性变革。其中值得注意的架构包括AlexNet(2012年)、GoogleNet(2014年)、VGGNet(2014年)和ResNet(2015年),它们包含了许多卷积层以提高图像分类的准确性。尽管图像分类任务涉及给整个图像分配标签,例如将一张图片分类为狗或汽车,但目标检测不仅识别图像中的内容,还精确地指出每个物体在图像中的位置。

原始的YOLO(2015年)论文在发布时在实时目标检测方面是一个突破,并且仍然是实际视觉应用中最常用的模型之一。它将检测过程从两到三阶段过程(即,R-CNN,Fast R-CNN)转变为单阶段卷积阶段,并在准确性和速度方面超越了所有最先进的目标检测方法。原始论文中的模型架构随着时间的推移发生了变化,通过添加不同的手工设计特征来提高模型的准确性。以下是YOLO前三个版本的概述及其差异。

YOLO v1 (2015)

YOLO v1是原始版本,为后续迭代奠定了基础。它使用单个深度卷积神经网络(CNN)来预测边界框和类别概率。YOLO v1将输入图像划分为网格,并在每个网格单元进行预测。每个单元负责预测固定数量的边界框及其相应的类别概率。这个版本以令人印象深刻的速度实现了实时目标检测,但在检测小物体和准确定位重叠物体方面存在一些限制。
在这里插入图片描述

YOLO v2 (2016)

YOLO v2解决了原始YOLO模型的一些限制。它引入了锚框,这有助于更好地预测不同大小和纵横比的边界框。YOLO v2使用了一个更强大的后端网络Darknet-19,并且不仅在原始数据集(PASCAL VOC)上训练,还在COCO数据集上训练,这显著增加了可检测类别的数量。锚框和多尺度训练的结合有助于提高小物体的检测性能。
在这里插入图片描述

YOLO v3 (2018)

YOLO v3进一步提高了目标检测的性能。这个版本引入了特征金字塔网络的概念,具有多个检测层,允许模型在不同的尺度和分辨率下检测物体。YOLO v3使用了一个更大的网络架构,有53个卷积层,称为Darknet-53,这提高了模型的表示能力。YOLO v3在三个不同的尺度上使用检测:13x13、26x26和52x52网格。每个尺度每个网格单元预测不同数量的边界框。
在这里插入图片描述

我们在预测多少边界框??在416 x 416的分辨率下,YOLO v1预测7 x 7 = 49个框。YOLO v2预测了13 x 13 x 5 = 845个框。对于YOLO v2,在每个网格单元,使用5个锚点检测5个框。另一方面,YOLO v3在3个不同的尺度上预测框。对于同样大小为416 x 416的图像,预测的框数为13 x 13 x 3 + 26 x 26 x 3 + 52 x 52 x 3 = 10,647。非极大值抑制(NMS),一种后处理技术,用于过滤冗余和重叠的边界框预测。在NMS算法中,首先,低于某个置信度分数的框从预测列表中删除。然后,置信度分数最高的预测被视为“当前”预测,所有置信度分数较低且与“当前”预测的IoU高于某个阈值(例如,0.5)的其他预测被标记为冗余并被抑制。有关在PyTorch中实现NMS,请参阅这个YouTube视频。

DETR

DETR (DEtection TRansformer)是一种相对较新的目标检测算法,由Facebook AI Research (FAIR)的研究人员在2020年引入。它基于transformer架构,这是一种强大的序列到序列模型,已被用于各种自然语言处理任务。传统的目标检测器(即,R-CNN和YOLO)复杂,经历了多次变化,并依赖于手工设计的组件(即,NMS)。另一方面,DETR是一个直接的集合预测模型,使用transformer编码器-解码器架构一次性预测所有物体。这种方法比传统目标检测器更简单、更高效,并在COCO数据集上实现了可比的性能。

DETR架构简单,由三个主要部分组成:用于特征提取的CNN后端(即,ResNet)、transformer编码器-解码器和用于最终检测预测的前馈网络(FFN)。后端处理输入图像并生成激活图。transformer编码器降低通道维度并应用多头自注意力和前馈网络。transformer解码器使用N个物体嵌入的并行解码,并独立预测箱子坐标和类别标签,使用物体查询。DETR利用成对关系,从整个图像上下文中受益,共同推理所有物体。
在这里插入图片描述

3、论文

以下代码(取自DETR的官方GitHub仓库)定义了这个DETR模型的前向传递,它通过包括卷积后端和transformer网络在内的各个层处理输入数据。我在代码中包含了每个网络层的输出形状,以了解所有的数据转换。

class DETRdemo(nn.Module):
    def __init__(self, num_classes, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # 2. 创建ResNet-50后端
        self.backbone = resnet50()
        del self.backbone.fc        # 创建转换层
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        # 3. 创建默认的PyTorch transformer
        self.transformer = nn.Transformer(hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
        # 4. 预测头,一个额外的类用于预测非空插槽
        # 注意,在基线DETR中线性_bbox层是3层MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        # 5. 输出位置编码(物体查询)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        # 空间位置编码
        # 注意,在基线DETR中我们使用正弦位置编码
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # 通过ResNet-50向上传播到平均池化层
        # 输入:torch.Size([1, 3, 800, 1066])
        x = self.backbone.conv1(inputs)    # torch.Size([1, 64, 400, 533])
        x = self.backbone.bn1(x)           # torch.Size([1, 64, 400, 533])
        x = self.backbone.relu(x)          # torch.Size([1, 64, 400, 533])
        x = self.backbone.maxpool(x)       # torch.Size([1, 64, 200, 267])
        x = self.backbone.layer1(x)        # torch.Size([1, 256, 200, 267])
        x = self.backbone.layer2(x)        # torch.Size([1, 512, 100, 134])
        x = self.backbone.layer3(x)        # torch.Size([1, 1024, 50, 67])
        x = self.backbone.layer4(x)        # torch.Size([1, 2048, 25, 34])
        # 从2048转换为256个特征平面供transformer使用
        h = self.conv(x)                   # torch.Size([1, 256, 25, 34])
        # 构建位置编码
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1) # torch.Size([850, 1, 256])
        src = pos + 0.1 * h.flatten(2).permute(2, 0, 1)  # torch.Size([850, 1, 256])
        target = self.query_pos.unsqueeze(1)    # torch.Size([100, 1, 256])
        # 通过transformer传播
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1), self.query_pos.unsqueeze(1).transpose(0, 1)) # torch.Size([1, 100, 256])
        linear_cls = self.linear_class(h)        # torch.Size([1, 100, 92])
        liner_bbx = self.linear_bbox(h).sigmoid()  # torch.Size([1, 100, 4])
        # 最后将transformer输出投影到类标签和边界框
        return {'pred_logits': linear_cls, 'pred_boxes': linear_bbx}

以下是代码的逐步解释:

初始化:__init__方法定义了DETR模块的结构。它接受几个超参数作为输入,包括类别数量(num_classes)、隐藏维度(hidden_dim)、注意力头数(nheads)以及编码器和解码器的层数(num_encoder_layers和num_decoder_layers)。

后端和转换层:代码创建了一个ResNet-50后端(self.backbone)并移除了其全连接(fc)层,因为检测时不会使用它。conv层(self.conv)被添加以将后端的输出从2048个通道转换为hidden_dim个通道。

transformer:使用nn.Transformer类(self.transformer)创建了一个PyTorch transformer。这个transformer将处理模型的编码器和解码器部分。根据提供的超参数设置编码器和解码器层的数量以及其他参数。

预测头:模型定义了两个线性层用于预测:self.linear_class预测类别logits。为了预测非空插槽,增加了一个额外的类别,因此是num_classes + 1。self.linear_bbox预测边界框的坐标。应用了.sigmoid()函数以确保边界框坐标在[0, 1]范围内。

位置编码:位置编码对于基于transformer的模型至关重要。模型定义了查询位置编码(self.query_pos)和空间位置编码(self.row_embed和self.col_embed)。这些编码帮助模型理解不同元素之间的空间关系。

模型产生100个有效预测。我们只保留输出概率高于特定限制的输出预测,并丢弃所有其他预测。

4、示例

在这一部分,我展示了一个来自我的GitHub仓库的示例项目,其中我使用了DETR和YOLO模型对实时视频流进行了处理。这个项目的目标是比较DETR在实时视频流上的性能与YOLO(这是大多数实时应用中的事实上的模型)的性能。下面的server.py脚本使用了来自Ultraalytics的YOLO v8和来自torch hub的预训练DETR模型。

import torch
from ultralytics import YOLO
import cv2
from dataclasses import dataclass
import time
from utils.functions import plot_results, rescale_bboxes, transform
from utils.datasets import LoadWebcam, LoadVideo
import logging
logging.basicConfig(level=logging.DEBUG,
                    format="%(asctime)s - %(levelname)s - %(message)s")
@dataclass
class Config:
    source: str = "assets/walking_resized.mp4"
    view_img: bool = False
    model_type: str = "detr_resnet50"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    skip: int = 1
    yolo: bool = True
    yolo_type = "yolov8n.pt"

class Detector:
    def __init__(self):
        self.config = Config()
        self.device = self.config.device
        if self.config.source == "0":
            logging.info("Using stream from the webcam")
            self.dataset = LoadWebcam()
        else:
            logging.info("Using stream from the video file: " + self.config.source)
            self.dataset = LoadVideo(self.config.source)
        self.start = time.time()
        self.count = 0

    def load_model(self):
        if self.config.yolo:
            if self.config.yolo_type is None or self.config.yolo_type == "":
                raise ValueError("YOLO model type is not specified")
            model = YOLO(self.config.yolo_type)
            logging.info(f"YOLOv8 Inference using {self.config.yolo_type}")
        else:
            if self.config.model_type is None or self.config.model_type == "":
                raise ValueError("DETR model type is not specified")
            model = torch.hub.load("facebookresearch/detr", self.config.model_type, pretrained=True).to(self.device)
            model.eval()
            logging.info(f"DETR Inference using {self.config.model_type}")
        return model

    def detect(self):
        model = self.load_model()
        for img in self.dataset:
            self.count += 1
            if self.count % self.config.skip != 0:
                continue
            if not self.config.yolo:
                im = transform(img).unsqueeze(0).to(self.device)
                outputs = model(im)
                # 只保留置信度0.7+的预测
                probas = outputs["pred_logits"].softmax(-1)[0, :, :-1]
                keep = probas.max(-1).values > 0.9
                bboxes_scaled = rescale_bboxes(outputs["pred_boxes"][0, keep].to("cpu"), img.shape[:2])
            else:
                outputs = model(img)
            logging.info(f"FPS: {self.count / self.config.skip / (time.time() - self.start)}")
            # print(f"FPS: {self.count / self.skip / (time.time() - self.start)}")
            if self.config.view_img:
                if self.config.yolo:
                    annotated_frame = outputs[0].plot()
                    cv2.imshow("YOLOv8 Inference", annotated_frame)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break
                else:
                    plot_results(img, probas[keep], bboxes_scaled)
        logging.info("************************* Done *****************************")

if __name__ == "__main__":
    detector = Detector()
    detector.detect()

server.py脚本负责从摄像头、IP摄像机或本地视频文件等来源获取数据。这个来源可以在server.py的config数据类中修改。性能评估显示,使用yolov8m.pt模型时,在Tesla T4 GPU上达到了每秒55帧(FPS)的惊人处理速度。另一方面,使用detr_resnet50模型的结果是每秒15 FPS的处理速度。

5、结论

YOLO是实时检测应用的绝佳选择,它专注于速度,适用于视频分析和实时目标跟踪等应用。另一方面,DETR在需要提高准确性和处理物体之间复杂交互的任务中表现出色,这在医学成像、细粒度目标检测以及检测质量优于实时处理速度的场景中可能特别重要。然而,重要的是要认识到,DETR的一个新迭代——被称为RT-DETR或实时DETR——在2023年发布,声称在速度和准确性方面都优于类似规模的所有YOLO检测器。这项创新虽然在这篇博客中没有涵盖,但它强调了这个领域的动态性质,以及根据特定应用需求进一步细化YOLO和DETR选择的潜力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/739074.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IDEA 2024.01版本 git分支merge合并

使用idea工具来进行merge合并 1、拉取远端分支信息 2、我的分支是sprint-240627,我要将test分支合并到我这个分支上 找到test分支 3、选择【Merge origin/test into sprint-240627】 从test合并到我们要合并得分支上,结束 4、如果有冲突,就解决冲突即可…

5. zabbix分布式监控

zabbix分布式监控 一、zabbix分布式监控二、zabbix分布式监控部署1、环境描述2、zabbix proxy的部署2.1 安装zabbix proxy相关的软件2.2 创建proxy需要的库、导入表2.3 编辑zabbix proxy配置文件,指定数据库连接2.4 启动zabbix proxy 3、在zabbix server添加代理4、…

基于 ROS 的 Terraform 托管服务轻松部署文本转语音系统 ChatTTS

介绍 ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。ChatTTS webUI & API 为 ChatTTS 提供了网页界面和API服务。 资源编排服务(Resource Orc…

竞赛选题 python opencv 深度学习 指纹识别算法实现

1 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 python opencv 深度学习 指纹识别算法实现 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:4分创新点:4分 该项目较为新颖…

.locked勒索病毒详解 | 防御措施 | 恢复数据

引言 在数字化飞速发展的今天,我们享受着信息技术带来的便捷与高效,然而,网络安全问题也随之而来,且日益严重。其中,勒索病毒以其狡猾的传播方式和巨大的破坏性,成为了网络安全领域中的一大难题。.locked勒…

创新、引领、发展——SAMPE中国2024年会在京盛大开幕

绿树阴浓夏日长,在这个色彩缤纷的季节,SAMPE中国2024年会暨第十九届国际先进复合材料制品原材料、工装及工程应用展览会在中国国际展览中心(北京朝阳馆)隆重开幕。新老朋友共聚一堂,把酒话桑麻。 为期4天的国际学术会…

TensorRT-LLM加速框架的基本使用

TensorRT-LLM是英伟达发布的针对大模型的加速框架,TensorRT-LLM是TensorRT的延申。TensorRT-LLM的GitHub地址是 https://github.com/NVIDIA/TensorRT-LLM 这个框架在0.8版本有一个比较大的更新,原先的逻辑被统一了,所以早期的版本就不介绍了…

使用鸿蒙HarmonyOs NEXT 开发b站的卡片效果 手把手教学

资源准备: 需要4张图片:分别是页面图,播放图标,评论图标,更多图标 1.实现效果显示: 2.教学视频: 使用鸿蒙HarmonyOs NEXT 开发b站卡片_哔哩哔哩_bilibilihttps://www.bilibili.com/video/BV1…

FPGA的基础仿真项目--七段数码管设计显示学号

一、设计实验目的 1. 了解数码管显示模块的工作原理。 2. 熟悉VHDL 硬件描述语言及自顶向下的设计思想。 3. 掌握利用FPGA设计6位数码管扫描显示驱动电路的方法。 二、实验设备 1. PC机 2.Cyclone IV FPGA开发板 三、扫描原理 下图所…

git检查别人提交的PR(pull requests)并在本地验证,然后合并

可以看官方流程:Checking out pull requests locally - GitHub Docs 当别人给你的开源仓库提交了pull request,你该怎么检查别人提交的代码是否可用,然后合并上去呢?今天我就遇到了,就在前不久开源项目douyin-live失败…

Day5 —— 电商日志数据分析项目

项目二 _____(电商日志数据分析项目) 引言需求分析详细思路统计页面浏览量Map阶段Reduce阶段 日志的ETL操作Map阶段Reduce阶段 统计各个省份的浏览量Map阶段Reduce阶段 具体步骤统计页面浏览量日志的ETL操作统计各个省份的浏览量工具类(utils…

mac鼠标和触摸屏单独设置滚动方向

引言:mac很好用,但是外接鼠标的滚动方向和win不一样,总有点不习惯。于是想要设置一下,当打开设置,搜索鼠标时,将“自然滚动”取消,就可以更改了。 问题:但触摸屏又不好用了。 原因&a…

无线麦克风哪个好?分享口碑最好的麦克风品牌

在这个自媒体时代,给了普通人很多的机会,尤其短视频的兴起更是让无数热情,有创作之心的人跃跃欲试。于是乎越来越多的人纷纷拿起了手机到各个平台去展示自己的才华,或者通过vlog记录分享自己的简单生活。可是在分享和创作的输出时…

ESP32 esp-idf esp-adf环境安装及.a库创建与编译

简介 ESP32 功能丰富的 Wi-Fi & 蓝牙 MCU, 适用于多样的物联网应用。使用freertos操作系统。 ESP-IDF 官方物联网开发框架。 ESP-ADF 官方音频开发框架。 文档参照 https://espressif-docs.readthedocs-hosted.com/projects/esp-adf/zh-cn/latest/get-started/index.…

Spring底层原理之bean的加载方式一 用XML方式声明bean 自定义bean及加载第三方bean 2024详解

目录 用XML方式声明bean 首先我们创建一个空的java工程 我们要导入一个spring的依赖 注意在maven工程里瞅一眼 我们创建一个业务层接口 还有四个实现类 我们最初的spingboot生命bean的方式是通过xml声明 我们在resources文件夹下创建一个配置文件 我们书写代码 首先初…

移动硬盘盒:便携与交互的完美结合 PD 充电IC

在数字化时代的浪潮中,数据已成为我们生活中不可或缺的一部分。随着数据的不断增长,人们对于数据存储的需求也在不断增加。传统的存储设备如U盘、光盘等,虽然具有一定的便携性,但在容量和稳定性方面往往难以满足现代人的需求。而移…

若依框架下拉单选框根据js动态加载,如何使select2的下拉搜素功能同时生效(达到select下拉框的样式不变的效果)

直接上代码,不废话 $(select[name"sealType"]).change(function (event) {let value event.target.valuequeeryDeptListBySealType(value)})// 获取科目信息function queeryDeptListBySealType(value){$.ajax({type: "post",url: prefix &quo…

竞赛选题 python+opencv+深度学习实现二维码识别

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 pythonopencv深度学习实现二维码识别 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:3分工作量:3分创新点:3分 该项目较为新颖&…

Mac提示此电脑不能读取您插的磁盘的原因,Mac磁盘无法读取内容怎么处理

为了能在不同设备中快速传输大容量的文件,我们常常会使用到外接磁盘进行文件的传输。但由于各种原因,比如硬件、文件系统格式等问题,Mac电脑插磁盘会出现无法读取的问题。本文会介绍Mac提示此电脑不能读取您插的磁盘的原因,以及Ma…

基于Java协同过滤算法的电影推荐系统设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟 感兴趣的可以先收藏起来,…