yolov8obb角度预测原理解析

预测头

ultralytics/nn/modules/head.py

class OBB(Detect):
    """YOLOv8 OBB detection head for detection with rotation models."""

    def __init__(self, nc=80, ne=1, ch=()):
        """Initialize OBB with number of classes `nc` and layer channels `ch`."""
        super().__init__(nc, ch)
        self.ne = ne  # number of extra parameters

        c4 = max(ch[0] // 4, self.ne)
        self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        bs = x[0].shape[0]  # batch size
        angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2)  # OBB theta logits
        # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
        angle = (angle.sigmoid() - 0.25) * math.pi  # [-pi/4, 3pi/4]
        # angle = angle.sigmoid() * math.pi / 2  # [0, pi/2]
        if not self.training:
            self.angle = angle
        x = Detect.forward(self, x)
        if self.training:
            return x, angle
        # return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

        return torch.cat([x, angle], 1).permute(0, 2, 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

forward 输入值
在这里插入图片描述

self.cv4网路结构

ModuleList(
  (0): Sequential(
    (0): Conv(
      (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (1): Sequential(
    (0): Conv(
      (conv): Conv2d(128, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (2): Sequential(
    (0): Conv(
      (conv): Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )

angle维度14,1,8400

损失函数

pred_angle = pred_angle.permute(0, 2, 1).contiguous()
维度变为14 8400 1

将预测结果转为bboxes

pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

计算回归损失

loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

这里的bbox_loss指的是:

self.bbox_loss = RotatedBboxLoss(self.reg_max - 1, use_dfl=self.use_dfl).to(self.device)

接来下看RotatedBboxLoss

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)

        return loss_iou, loss_dfl

两个旋转矩形如何计算IOU:

def probiou(obb1, obb2, CIoU=False, eps=1e-7):
    """
    Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf.

    Args:
        obb1 (torch.Tensor): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
        obb2 (torch.Tensor): A tensor of shape (N, 5) representing predicted obbs, with xywhr format.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): A tensor of shape (N, ) representing obb similarities.
    """
    x1, y1 = obb1[..., :2].split(1, dim=-1)
    x2, y2 = obb2[..., :2].split(1, dim=-1)
    a1, b1, c1 = _get_covariance_matrix(obb1)
    a2, b2, c2 = _get_covariance_matrix(obb2)

    t1 = (
        ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
    ) * 0.25
    t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
    t3 = (
        ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
        / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
        + eps
    ).log() * 0.5
    bd = (t1 + t2 + t3).clamp(eps, 100.0)
    hd = (1.0 - (-bd).exp() + eps).sqrt()
    iou = 1 - hd
    if CIoU:  # only include the wh aspect ratio part
        w1, h1 = obb1[..., 2:4].split(1, dim=-1)
        w2, h2 = obb2[..., 2:4].split(1, dim=-1)
        v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
        with torch.no_grad():
            alpha = v / (v - iou + (1 + eps))
        return iou - v * alpha  # CIoU
    return iou

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/756686.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Dison夏令营 Day 02】使用 Python 玩井字游戏

在本文中,我们将介绍使用 Python 语言从零开始创建井字游戏的步骤。 在本文中,我们将介绍使用 Python 语言从零开始创建井字游戏的步骤。 游戏简介 井字游戏是一种双人游戏,在 33 正方形网格上进行。每位玩家轮流占据一个单元格&#xff0c…

CMake(1)基础使用

CMake之(1)基础使用 Author: Once Day Date: 2024年6月29日 一位热衷于Linux学习和开发的菜鸟,试图谱写一场冒险之旅,也许终点只是一场白日梦… 漫漫长路,有人对你微笑过嘛… 全系列文章可参考专栏: Linux实践记录_Once-Day的博客-CSDN博客…

双指针算法第一弹(移动零 复写零 快乐数)

目录 前言 1. 移动零 (1)题目及示例 (2)一般思路 (3)双指针解法 2. 复写零 (1)题目及示例 (2)一般解法 (3)双指针解法 3. 快…

计算机基础知识——C基础+C指针+char类型

指针 这里讲的很细 https://blog.csdn.net/weixin_43624626/article/details/130715839 内存地址:内存中每个字节单位都有一个编号(一般用十六进制表示) 存储类型 数据类型 *指针变量名;int *p; //定义了一个指针变量p,指向的数…

在Redis中使用Lua脚本实现多条命令的原子性操作

Redis作为一个高性能的键值对数据库,被广泛应用于各种场景。然而,在某些情况下,我们需要执行一系列Redis命令,并确保这些命令的原子性。这时,Lua脚本就成为了一个非常实用的解决方案。 问题的提出 假设我们有一个计数…

【深度学习】图形模型基础(2):概率机器学习模型与人工智能

1.引言 1.1.背景 当机器需要从经验中汲取知识时,概率建模成为了一个至关重要的工具。它不仅为理解学习机制提供了理论框架,而且在实际应用中,特别是在设计能够从数据中学习的机器时,概率建模展现出了其独特的价值。概率框架的核…

Power BI可视化表格矩阵如何保持样式导出数据?

故事背景: 有朋友留言询问:自己从Power BI可视化矩阵表格中导出数据时,导出的表格样式会发生改变,需要线下再手动调整,重新进行透视组合成自己想要的格式。 有没有什么办法让表格导出来跟可视化一样? Po…

汽车电子工程师入门系列——CAN 规范系列通读

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不对。非必要不费力证明自己,无利益不试图说服别人,是精神上的节…

SiteSucker Pro for Mac:一键下载整站,轻松备份与离线浏览!

SiteSucker Pro for Mac是一款专为苹果电脑用户设计的网站下载与备份工具🕸️。它以其强大的整站下载能力和用户友好的界面,成为了众多Mac用户备份网站、离线浏览的得力助手💻。 这款软件允许用户一键下载整个网站,包括所有的网页…

Docker(八)-Docker运行mysql8容器实例

1.运行mysql8容器实例并挂载数据卷 -e:配置环境变量 --lower_case_table_names1 设置忽略表名大小写一定要放在镜像之后运行mysql8容器实例之前,先查看是否存在mysql8镜像以及是否存在已运行的mysql实例docker run -d -p 3306:3306 --privilegedtrue -v 【宿主机日…

L03_Redis知识图谱

这些知识点你都掌握了吗?大家可以对着问题看下自己掌握程度如何?对于没掌握的知识点,大家自行网上搜索,都会有对应答案,本文不做知识点详细说明,只做简要文字或图示引导。 Redis 全景图 Redis 知识全景图都包括什么呢?简单来说,就是“两大维度,三大主线”。 Redis …

MySQL连接IDEA(Java Web)保姆级教程

第一步:新建项目(File)->Project 第二步:New Project(JDK最好设置1.8版本与数据库适配,详细适配网请到MySQL官网查询MySQL :: MySQL 8.3 Reference Manual :: Search Results) 第三步:点中MySQLTest(项目名)并连续双击shift键-…

昇思25天学习打卡营第2天|数据集Dataset

学习目标:熟练掌握mindspore.dataset mindspore.dataset中有常用的视觉、文本、音频开源数据集供下载,点赞、关注收藏哦 了解mindspore.dataset mindspore.dataset应用实践 拓展自定义数据集 昇思平台学习时间记录: 一、关于mindspore.dataset minds…

【STM32】在标准库中使用定时器

1.TIM简介 STM32F407系列控制器有2个高级控制定时器、10个通用定时器和2个基本定时器。通常情况下,先看定时器挂在哪个总线上APB1或者APB2,然后定时器时钟需要在此基础上乘以2。 2.标准库实现定时中断 #ifndef __BSP_TIMER_H #define __BSP_TIMER_H#if…

.[emcrypts@tutanota.de].mkp勒索病毒新变种该如何应对?

引言 在数字化时代,随着信息技术的迅猛发展,网络安全问题日益凸显。其中,勒索病毒作为一种极具破坏力的恶意软件,给个人和企业带来了巨大的经济损失和数据安全风险。近期,一种名为“.mkp勒索病毒”的新型威胁开始在网络…

多线程引发的安全问题

前言👀~ 上一章我们介绍了线程的一些基础知识点,例如创建线程、查看线程、中断线程、等待线程等知识点,今天我们讲解多线程下引发的安全问题 线程安全(最复杂也最重要) 产生线程安全问题的原因 锁(重要…

在 Python 中创建列表时,应该写 `[]` 还是 `list()`?

在 Python 中,创建列表有两种写法: # 写法一:使用一对方括号 list_1 []# 写法二:调用 list() list_2 list() 那么哪种写法更好呢? 单从写法上来看,[] 要比 list() 简洁,那在性能和功能方面…

江科大笔记—读写内部闪存FLASH读取芯片ID

读写内部闪存FLASH 右下角是OLED,然后左上角在PB1和PB11两个引脚,插上两个按键用于控制。下一个代码读取芯片ID,这个也是接上一个OLED,能显示测试数据就可以了。 STM32-STLINK Utility 本节的代码调试,使用辅助软件…

[机缘参悟-200] - 对自然、人性、人生、人心、人际、企业、社会、宇宙全面系统的感悟 - 全图解

对自然、人性、人生、人心、人际、企业、社会、宇宙进行全面系统的感悟,是一个极其深邃且复杂的主题。以下是对这些领域的简要感悟: 自然: 自然是人类生存的根基,它充满了无尽的奥秘和美丽。自然界的平衡和循环规律,教…

运算符重载之日期类的实现

接上一篇文章&#xff0c;废话不多说&#xff0c;直接上代码 Date.h #pragma once #include<iostream> using namespace std; #include<assert.h>class Date {//友元函数声明friend ostream& operator<<(ostream& out, const Date& d);friend …