非最大值抑制(NMS)函数

非最大值抑制(NMS)函数

flyfish

非最大值抑制(Non-Maximum Suppression, NMS)是计算机视觉中常用的一种后处理技术,主要用于目标检测任务。其作用是从一组可能存在大量重叠的候选边界框中,筛选出最具代表性的边界框,即通过置信度分数和重叠区域的过滤,保留最具代表性的边界框。

边界框(Bounding Boxes):一组表示候选目标区域的矩形框,每个框由左上角和右下角的坐标(x1, y1, x2, y2)表示。
置信度分数(Confidence Scores):每个边界框对应的一个置信度分数,表示该框内包含目标的可能性。

执行步骤

初始化:
boxes:输入的边界框列表。
scores:每个边界框对应的置信度得分列表。
confidence_threshold:过滤边界框的最低置信度阈值。
iou_threshold:用于确定边界框是否重叠的 IOU 阈值。

过滤低置信度边界框:
根据 confidence_threshold 过滤掉置信度低于该阈值的边界框。

按置信度排序:
对剩余的边界框按照置信度从高到低排序。

非极大值抑制:
从排序后的列表中选择置信度最高的边界框,并计算其与其他边界框的 Intersection-over-Union (IoU)。
如果 IoU大于 iou_threshold,则移除该边界框(表示重叠太多)。
重复该过程直到处理完所有边界框。

返回结果:
返回保留的边界框的索引。
在这里插入图片描述
可视化 Intersection-over-Union (IoU)

蓝色矩形表示 Box A,红色矩形表示 Box B,绿色矩形表示它们的交集区域,剩余的红色和蓝色是并集区域。
在这里插入图片描述

torchvision.ops.nms 和 cv2.dnn.NMSBoxes 的调用

import numpy as np
import torch
import torchvision.ops as ops
import cv2

# 输入数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])# (x1, y1, x2, y2)格式
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

def convert_to_xywh(boxes): #opencv用 (x, y, w, h)格式
    """
    将边界框从 (x1, y1, x2, y2) 格式转换为 (x, y, w, h) 格式。
    
    参数:
    - boxes: 形状为 (N, 4) 的数组,其中 N 是边界框的数量
    
    返回:
    - boxes_xywh: 形状为 (N, 4) 的数组,包含转换后的边界框
    """
    boxes_xywh = np.zeros_like(boxes)
    boxes_xywh[:, 0] = boxes[:, 0]  # x
    boxes_xywh[:, 1] = boxes[:, 1]  # y
    boxes_xywh[:, 2] = boxes[:, 2] - boxes[:, 0]  # w
    boxes_xywh[:, 3] = boxes[:, 3] - boxes[:, 1]  # h
    return boxes_xywh

def nms_torchvision(boxes, scores, nms_threshold):
    boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
    scores_tensor = torch.tensor(scores, dtype=torch.float32)
    keep = ops.nms(boxes_tensor, scores_tensor, nms_threshold)
    return keep.numpy()

def nms_opencv(boxes, scores, score_threshold, nms_threshold):
    boxes = convert_to_xywh(boxes)
    indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), score_threshold, nms_threshold)
    return np.array(indices).flatten()

# 调用 NMS
keep_torchvision = nms_torchvision(boxes, scores, nms_threshold)
keep_opencv = nms_opencv(boxes, scores, score_threshold, nms_threshold)

print("使用 torchvision.ops.nms 保留的边界框索引: ", keep_torchvision)
print("使用 cv2.dnn.NMSBoxes 保留的边界框索引: ", keep_opencv)

输出

使用 torchvision.ops.nms 保留的边界框索引:  [0 3 1 7 2 4]
使用 cv2.dnn.NMSBoxes 保留的边界框索引:  [0 3 1 7 2 4]

用纯 NumPy 实现的非最大值抑制(NMS)函数

import numpy as np

def nms(boxes, scores, score_threshold, nms_threshold):
    """单类 NMS 使用 NumPy 实现。"""
    # 过滤掉低置信度的框
    indices = np.where(scores > score_threshold)[0]
    boxes = boxes[indices]
    scores = scores[indices]

    # 如果没有剩余的框,返回空列表
    if len(boxes) == 0:
        return []

    # 提取每个边界框的坐标
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    # 计算每个边界框的面积
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # 根据分数进行排序(从高到低)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(indices[i])
        # 计算当前边界框与其余边界框的交集坐标
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        # 计算交集的宽度和高度
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        # 计算交集面积
        inter = w * h
        # 计算交并比(IOU)
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # 只保留 IOU 小于阈值的边界框
        inds = np.where(ovr <= nms_threshold)[0]
        order = order[inds + 1]

    return keep

# 示例数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

# 调用NMS
keep_indices = nms(boxes, scores, score_threshold, nms_threshold)
print("使用 NumPy 实现的 NMS 保留的边界框索引: ", keep_indices)
使用 NumPy 实现的 NMS 保留的边界框索引:  [0, 3, 1, 7, 2, 4]

关于语法的解释

在 NumPy 中,冒号 : 用于数组切片。它们可以用来提取数组的子集、重排数组或选取特定的元素。

示例1

scores.argsort()[::-1]
scores.argsort():返回 scores 中元素的索引数组,这些索引会将 scores 排序。
[::-1]:表示反转数组。
在这个例子中,[::-1] 表示从开始到结束,步长为 -1,因此数组会被反转。这里的两个冒号是为了清楚地表示切片的完整语法 [start:stop:step],其中省略了 start 和 stop,只指定了 step 为 -1。

import numpy as np

scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
sorted_indices = scores.argsort()  # 升序排序的索引
print("sorted_indices:", sorted_indices)

# 反转排序索引(降序排序)
reversed_indices = sorted_indices[::-1]
print("reversed_indices:", reversed_indices)
sorted_indices: [8 5 4 2 7 1 6 3 0]
reversed_indices: [0 3 6 1 7 2 4 5 8]

示例2

boxes[:, 0]
boxes[:, 0]:选取 boxes 数组中第 0 列的所有元素。
: 表示选择所有行,0 表示选择第 0 列。
这段代码的作用是提取 boxes 数组中每个边界框的 x1 坐标(左上角的 x 坐标)。

import numpy as np
boxes = np.array([
    [100, 100, 210, 210],
    [220, 220, 320, 330],
    [300, 300, 400, 400],
    [50, 50, 150, 200]
])

x1 = boxes[:, 0]
print("x1:", x1)
x1: [100 220 300  50]

可视化数据的代码

def plot_boxes(boxes, keep_indices):
    fig, ax = plt.subplots(1, figsize=(12, 12))

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box
        width = x2 - x1
        height = y2 - y1

        # 所有输入框用蓝色绘制
        edgecolor = 'blue'
        if i in keep_indices:
            # NMS 保留的框用绿色绘制
            edgecolor = 'green'
        else:
            # 被抑制的框用红色绘制
            edgecolor = 'red'
        
        rect = patches.Rectangle((x1, y1), width, height, linewidth=2, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    # 设置坐标范围
    ax.set_xlim(0, np.max(boxes[:, [0, 2]]) + 50)
    ax.set_ylim(0, np.max(boxes[:, [1, 3]]) + 50)
    ax.invert_yaxis()  # 图像坐标系和实际坐标系相反时需要

    plt.show()

# 示例数据
boxes = np.array([
    [100, 100, 210, 210], [220, 220, 320, 330], [300, 300, 400, 400],
    [50, 50, 150, 200], [200, 150, 280, 320], [280, 280, 380, 380],
    [80, 90, 190, 210], [250, 250, 350, 370], [290, 290, 390, 390]
])
scores = np.array([0.9, 0.8, 0.75, 0.85, 0.7, 0.65, 0.82, 0.78, 0.6])
score_threshold = 0.5
nms_threshold = 0.4

# 调用NMS
keep_indices = nms(boxes, scores, score_threshold, nms_threshold)
print("使用 NumPy 实现的 NMS 保留的边界框索引: ", keep_indices)

# 绘图
plot_boxes(boxes, keep_indices)

可视化 Intersection-over-Union (IoU)的代码

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def plot_iou(boxA, boxB):
    fig, ax = plt.subplots(1, figsize=(8, 8))

    # 绘制 Box A
    x1A, y1A, x2A, y2A = boxA
    widthA = x2A - x1A
    heightA = y2A - y1A
    rectA = patches.Rectangle((x1A, y1A), widthA, heightA, linewidth=2, edgecolor='blue', facecolor='blue', label='Box A')
    ax.add_patch(rectA)

    # 绘制 Box B
    x1B, y1B, x2B, y2B = boxB
    widthB = x2B - x1B
    heightB = y2B - y1B
    rectB = patches.Rectangle((x1B, y1B), widthB, heightB, linewidth=2, edgecolor='red', facecolor='red', label='Box B')
    ax.add_patch(rectB)

    # 计算交集
    xx1 = np.maximum(x1A, x1B)
    yy1 = np.maximum(y1A, y1B)
    xx2 = np.minimum(x2A, x2B)
    yy2 = np.minimum(y2A, y2B)

    w = np.maximum(0, xx2 - xx1)
    h = np.maximum(0, yy2 - yy1)
    intersection_area = w * h

    # 计算并集
    areaA = (x2A - x1A) * (y2A - y1A)
    areaB = (x2B - x1B) * (y2B - y1B)
    union_area = areaA + areaB - intersection_area

    # 计算 IoU
    iou = intersection_area / union_area

    # 绘制交集
    if w > 0 and h > 0:
        rect_intersection = patches.Rectangle((xx1, yy1), w, h, linewidth=2, edgecolor='green', facecolor='green', linestyle='--', label='Intersection')
        ax.add_patch(rect_intersection)

    # 显示图例
    handles, labels = ax.get_legend_handles_labels()

    plt.legend(handles=handles)

    plt.xlim(0, 500)
    plt.ylim(0, 500)
    plt.gca().set_aspect('equal', adjustable='box')
    plt.title(f'IoU = {iou:.2f}')
    plt.show()

# 示例框
boxA = [100, 100, 300, 300]
boxB = [200, 200, 400, 400]

plot_iou(boxA, boxB)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/749836.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从CVPR 2024看域适应、域泛化最新研究进展

域适应和域泛化一直以来都是各大顶会的热门研究方向。 域适应指&#xff1a;当我们在源域上训练的模型需要在目标域应用时&#xff0c;如果两域数据分布差异太大&#xff0c;模型性能就有可能降低。这时可以利用目标域的无标签数据&#xff0c;通过设计特定方法减小域间差异&a…

thinksboard 新建子类菜单

新建需要的文件 打开bz-routing.module.ts文件&#xff0c;设置bzRoutes&#xff0c;为下面使用 import { Injectable, NgModule } from angular/core; import { Resolve, RouterModule, Routes } from angular/router; import { Authority } from shared/models/authority.en…

【创建者模式-工厂模式】

简单工厂模式 &#xff08;也称为静态工厂模式&#xff09;由一个工厂对象负责创建所有产品类的实例。客户端通过传入一个参数给工厂类来请求创建哪种产品类的实例。这种模式的优点在于客户端不需要知道具体的产品类&#xff0c;只需要知道对应的参数即可。缺点是当需要添加新…

redis复习

redis知识点 redis持久化redis 订阅发布模式redis主从复制哨兵模式redis雪崩&#xff0c;穿透缓存击穿&#xff08;请求太多&#xff0c;缓存过期&#xff09;缓存雪崩 redis持久化 redis是内存数据库&#xff0c;持久化有两种方式&#xff0c;一种是RDB&#xff08;redis dat…

【解决方案】你必须要知道的~前端九种跨域方式实现原理(完整版)

前言 前后端数据交互经常会碰到请求跨域&#xff0c;什么是跨域&#xff0c;以及有哪几种跨域方式&#xff0c;这些问题通常出现在Web开发中&#xff0c;当浏览器执行脚本发起请求到不同的域名、协议或端口时&#xff0c;出于安全考虑&#xff0c;浏览器会限制这种跨源HTTP请求…

Redis数据库(六):主从复制和缓存穿透及雪崩

目录 一、Redis主从复制 1.1 概念 1.2 主从复制的作用 1.3 实现一主二从 1.4 哨兵模式 1.4.1 哨兵的作用 1.4.2 哨兵模式的优缺点 二、Redis缓存穿透和雪崩 2.1 缓存穿透——查不到 2.1.1 缓存穿透解决办法 2.2 缓存击穿 - 量太大&#xff0c;缓存过期 2.2.1 缓存…

拍照就用华为Pura 70系列,后置真实感人像轻松出片!

平时喜欢用手机记录生活的人是不是总有个烦恼&#xff0c;想要拍出媲美单反的完美人像&#xff0c;又怕照片失真&#xff0c;经过近期对手机摄影的探索&#xff0c;我发现了华为Pura70系列的真实感人像之美&#xff0c;它给予每个热爱生活的人直面镜头的自信&#xff0c;记录真…

毕业季留念,就该这样记录下来

毕业季来啦&#xff01;这个季节总是充满了不舍和期待&#xff0c;就像夏天里的冰淇淋&#xff0c;甜蜜中带着一丝丝凉意。在这个特别的时刻&#xff0c;我想和大家分享一款陪伴我记录青春点滴的神器——nova 12 Ultra 手机。 要说自拍&#xff0c;我可是个“资深玩家”。以前…

以算筑基,以智赋能 | Gooxi受邀出席2024中国智算中心全栈技术大会

6月25日&#xff0c;2024中国智算中心全栈技术大会暨展览会、第5届中国数据中心绿色能源大会暨第10届中国&#xff08;上海&#xff09;国际数据中心产业展览会在上海新国际博览中心隆重召开。Gooxi受邀参与并携最新服务器产品以及解决方案亮相展会&#xff0c;吸引众多行业领袖…

基于MATLAB仿真设计无线充电系统

通过学习无线充电相关课程知识&#xff0c;通过课程设计无线充电系统&#xff0c;将所学习的WPT&#xff0c;DC-DC&#xff0c;APFC进行整合得到整个无线充电系统&#xff0c;通过进行仿真研究其系统特性&#xff0c;完成我们预期系统功能和指标。 以功率器件为基本元件&#x…

【人工智能学习之图像操作(二)】

【人工智能学习之图像操作&#xff08;二&#xff09;】 图像上的运算图像混合按位运算 图像的几何变换仿射变换透视变换膨胀操作腐蚀操作开操作闭操作梯度操作礼帽操作黑帽操作 图像上的运算 图像上的算术运算&#xff0c;加法&#xff0c;减法&#xff0c;图像混合等。 加减…

Profibus协议转Modbus协议网关模块在船舶中的应用

一、背景 在当今数字化快速发展的时代&#xff0c;船舶作为重要的交通工具之一&#xff0c;也在不断追赶着科技的步伐&#xff0c;实现自身的智能化升级。而在这个过程中&#xff0c;Profibus转Modbus网关&#xff08;XD-MDPB100&#xff09;作为关键的一环&#xff0c;扮演着…

05 Shell编程之免交互

目录 5.1 Here Document 免交互 5.1.1 Here Document 概述 5.1.2 Here Document 免交互 1. 通过read命令接收输入并打印 5.1.3 Here Document变量设定 5.1.4 Here Document 格式控制 (1)关闭变量替换的功能。 (2)去掉每行之前的TAB字符。 5.1.5 Here Document 多行注释…

前端写代码真的有必要封装太好么?

前言 封装、代码复用、设计模式…… 这些都是方法&#xff0c;业务才是目的。技术始终是为业务服务的。能够满足业务需求&#xff0c;并且用起来舒服的&#xff0c;都是好方法。 不存在一套适用于所有项目的最佳代码组织方法&#xff0c;你需要结合业务&#xff0c;去不断地…

cad报错:由于找不到vcruntime140.dll无法继续执行代码

在现代的工程设计中&#xff0c;计算机辅助设计&#xff08;CAD&#xff09;软件已经成为了工程师们不可或缺的工具。然而&#xff0c;在使用CAD软件的过程中&#xff0c;有时我们会遇到一些问题&#xff0c;其中之一就是“找不到vcruntime140.dll”的错误提示。本文将详细介绍…

鸿蒙期末项目(2)

主界面 主界面和商店详情界面参考如下设计图&#xff08;灵感严重匮乏&#xff09; 简单起见&#xff0c;将整个app分为4个布局&#xff0c;分别是主界面、搜索界面、购物车界面&#xff0c;以及个人界面。 所以在app中也需要使用tab组件进行分割&#xff0c;且需要通过tabBa…

安装Flask

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 大多数Python包都使用pip实用工具安装&#xff0c;使用Virtualenv创建虚拟环境时会自动安装pip。激活虚拟环境后&#xff0c;pip 所在的路径会被添加…

离散傅里叶变化

傅里叶变换 对傅里叶变换了解不是很清楚的朋友推荐一下这个帖子&#xff0c;讲得很详细 傅里叶变换 源码 先看源码链接 #include "opencv2/core.hpp" #include "opencv2/imgproc.hpp" #include "opencv2/imgcodecs.hpp" #include "open…

FuTalk设计周刊-Vol.026

&#x1f525;&#x1f525;AI漫谈 热点捕手&#x1f525;&#x1f525; 1、Hotshot-XL AI文本转GIF Hotshot-XL 是一种 AI 文本转 GIF 模型&#xff0c;经过训练可与Stable Diffusion XL一起使用。能够使用任何现有或新微调的 SDXL 模型制作 GIF。 网页体验 网页http://htt…

git 初基本使用-----------笔记(结合idea)

Git命令 下载git 打开Git官网&#xff08;git-scm.com&#xff09;&#xff0c;根据自己电脑的操作系统选择相应的Git版本&#xff0c;点击“Download”。 基本的git命令使用 可以在项目文件下右击“Git Bash Here” &#xff0c;也可以命令终端下cd到指定目录执行初始化命令…