【技术总结】常用指标mAP, mIoU, mDice, mFscore, aAcc 实现

mAP

mAP 全称是 mean Average Precision. 其中 mean 这个操作是在 class 级别上, 因此只需要将所有类别平均即可. 所有需要关注的就是 AP. AP 是 Precision-Recall 曲线和坐标轴围成的面积. 提到曲线可能会感觉比较懵 – 模型的预测对或者不对都是确定的, 哪里来的曲线呢?

想要搞明白为什么有曲线, 得看模型的预测结果. 一般来说, 模型在给出类别的预测的时候, 都会有一个置信度 p p p 表示属于这一类的概率. 因此我们就可以设定一个阈值 t t t, 如果 p > t p>t p>t, 认为属于预测类 L \boldsymbol{L} L, 如果 p ≤ t p \leq t pt 认为不属于 L \boldsymbol{L} L. 因此当我们设定不同的 t t t 的时候, 得到的 Precision 和 Recall 就会不同.

取不同的值, 就可以得到不同的 Precision, 可以写成函数 P = g ( t ) P = g(t) P=g(t), 正式来说 g ( t ) g(t) g(t) 就是 Precision 曲线, 横坐标是阈值 t t t, 纵坐标是对应的 Precision. 同理也可以得到 Recall 曲线, 表示为 R = f ( t ) R = f(t) R=f(t), f ( t ) f(t) f(t) 是单调递减的. 所以 P = g ( f − 1 ( R ) ) P = g(f^{-1}(R)) P=g(f1(R)), 这就是 Precision-Recall 曲线.

数据集只有有限个样本, 这种情况下如何得到 Precision-Recall 曲线呢?
数据集只有有限个样本( N N N), 每个样本都会得到一个置信度 p p p. 按照置信度排序之后, 我们就可以知道当阈值 t t t 精确地设定为 p i p_i pi 的时候对应的 Precision 和 Recall, 这样也就得到了 N N N 个 Precision - Recall 点对. 如下所示:

在这里插入图片描述

下面是实现的代码, 有些 trick 需要仔细理解.

class BBox(object):
    def __init__(self, img_name, bbox, conf):
        self.name = img_name
        self.bbox = bbox 
        self.conf = conf

    
    @staticmethod
    def IoU(bbox_a, bbox_b):
        bbox_a = bbox_a.bbox 
        bbox_b = bbox_b.bbox
        area1 = (bbox_a[0] - bbox_a[2]) * (bbox_a[1] - bbox_b[3])
        area2 = (bbox_b[0] - bbox_b[2]) * (bbox_b[1] - bbox_b[3])

        tx = min(bbox_a[0], bbox_b[0])
        ty = min(bbox_a[1], bbox_b[1]) 
        
        bx = max(bbox_a[2], bbox_b[2])
        by = max(bbox_a[3], bbox_b[3])

        inter_area = (bx - tx) * (by - ty)

        return inter_area / (area1 + area2 - inter_area)


def compute_AP(prds, gts, IoU_threshold=0.5):
    """给定预测和ground truth 计算AP值

    Args:
        prds (List[]): 模型预测结果
        gts (_type_): 真值
        IoU_threshold (float, optional): _description_. Defaults to 0.5.
    """

    # sort the prds using conf
    prds = sorted(prds, key=lambda x: x.conf, reverse=True)

    # collect the gts
    name2gt_idx = dict()
    for i, gt in enumerate(gts):
        name2gt_idx[gt.name] = name2gt_idx.get(gt.name, []) + [i]
    
    # compute the TP and FP
    TP = np.zeros(len(prds))
    FP = np.zeros(len(prds))

    # initial used_gt
    used_gt = np.zeros(len(gts))

    for pred_idx, prd in enumerate(prds):
        gt_idxs = name2gt_idx[prd.name]

        # find the gt with max iou
        max_iou_idx = -1
        max_iou = 0
        for gt_idx in gt_idxs:
            gt = gts[gt_idx]
			
			# 如果这个GT bbox已经使用过, 跳过
            if used_gt[gt_idx]:
                continue 

            iou = BBox.IoU(gt, prd)
            if iou > IoU_threshold and iou > max_iou:
                max_iou = iou 
                max_iou_idx = gt_idx
        
        if max_iou_idx != -1:
            TP[pred_idx] = 1
            # 标记这个GT bbox已经使用, 下次不能匹配这个GT bbox
            used_gt[max_iou_idx] = 1
        else:
            FP[pred_idx] = 1
    
    # compute the AP
    acc_TP = np.cumsum(TP)
    acc_FP = np.cumsum(FP)

    n_total = len(gts)

    Pr = acc_TP / (acc_TP + acc_FP)
    Rc = acc_TP / n_total 

    return _compute_AP(Pr, Rc)


def _compute_AP(Pr, Rc):
    Rc = np.concatenate(([0], Rc, [1]))
    Pr = np.concatenate(([0], Pr, [0]))

    # compute the real Pr
    # Pr 曲线不是严格递减, 因此实际计算只取右边最大值
    for i in range(len(Pr) - 1, 0, -1):
        Pr[i-1] = max(Pr[i-1], Pr[i])
    
    # compute the AP
    index = np.where(Rc[1:] != Rc[:-1])[0]
    
    AP = np.sum(Pr[index + 1] * (Rc[index+1] - Rc[index]))
    return AP

mIoU, mDice, mFscore, aAcc 等计算

其中 m 都是指在类比上的mean, 因此只需要考虑如何计算单个类别即可. 下面是 mmsegmentation 中的实现, 比较简洁高效.

  • 求出每张图的 intersection, union, area_pred, area_label, 后面直接在所有样本上累加得到 total_intersection, total_union, total_area_pred, total_label 再进行计算, 这样可以避免 某张图 intersection 或者 area_pred 等为0的情况.
  • 也就是像素级别求解IoU, 然后类别上平均
def intersect_and_union(pred_label: torch.tensor, label: torch.tensor,
                            num_classes: int, ignore_index: int):
        """Calculate Intersection and Union.

        Args:
            pred_label (torch.tensor): Prediction segmentation map
                or predict result filename. The shape is (H, W).
            label (torch.tensor): Ground truth segmentation map
                or label filename. The shape is (H, W).
            num_classes (int): Number of categories.
            ignore_index (int): Index that will be ignored in evaluation.

        Returns:
            torch.Tensor: The intersection of prediction and ground truth
                histogram on all classes.
            torch.Tensor: The union of prediction and ground truth histogram on
                all classes.
            torch.Tensor: The prediction histogram on all classes.
            torch.Tensor: The ground truth histogram on all classes.
        """

        mask = (label != ignore_index)
        pred_label = pred_label[mask]
        label = label[mask]
		
		# intersect 的值为 [0 - num_classes-1]
		# 值==i的个数是i类别上的intersection
        intersect = pred_label[pred_label == label]
        area_intersect = torch.histc(
            intersect.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_pred_label = torch.histc(
            pred_label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_label = torch.histc(
            label.float(), bins=(num_classes), min=0,
            max=num_classes - 1).cpu()
        area_union = area_pred_label + area_label - area_intersect
        return area_intersect, area_union, area_pred_label, area_label

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/488284.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

字节算法岗二面,凉凉。。。

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学,针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总…

php反序列化刷题1

[SWPUCTF 2021 新生赛]ez_unserialize 查看源代码想到robots协议 看这个代码比较简单 直接让adminadmin passwdctf就行了 poc <?php class wllm {public $admin;public $passwd; }$p new wllm(); $p->admin "admin"; $p->passwd "ctf"; ec…

【第三方登录】Google邮箱

登录谷歌邮箱开发者 https://console.developers.google.com/ 先创建项目 我们用的web应用 设置回调 核心主要&#xff1a; 1.创建应用 2.创建客户端ID 3.设置域名和重定向URL 4.对外公开&#xff0c;这样所有的gmail邮箱 都能参与测试PHP代码实现 引入第三方包 h…

【云能耗管理系统在某大型商场的应用】安科瑞Acrel-EIOT能源物联网平台方案

摘要&#xff1a;依据对上海市某大型商场现场考察的结果&#xff0c;提出通过建设云能耗管理系统的方案来改善商场能耗的管理现状。首先充分搜集建筑信息和设备运行工况&#xff0c;合理设计系统实施方案&#xff0c;解决现场数据采集和传输障碍&#xff0c;完成云能耗管理系统…

常用设计模式介绍

前言 简说设计模式。 文章目录 前言一、设计模式的要素1、设计模式解决的问题2、设计模式分类1&#xff09;创建型设计模式2&#xff09;结构型设计模式3&#xff09;行为型设计模式 二、详细介绍1、创建型设计模式1&#xff09;工厂方法模式2&#xff09;抽象工厂模式3&#x…

【JavaEE】进程是什么?

文章目录 ✍进程的概念✍进程存在的意义✍进程在计算机中的存在形式✍进程调度 ✍进程的概念 每个应⽤程序运⾏于现代操作系统之上时&#xff0c;操作系统会提供⼀种抽象&#xff0c;好像系统上只有这个程序在运⾏&#xff0c;所有的硬件资源都被这个程序在使⽤。这种假象是通…

上位机图像处理和嵌入式模块部署(qmacvisual拟合圆和拟合椭圆)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 前面我们学习了拟合直线&#xff0c;今天继续学习下拟合圆和拟合椭圆。其实除了最后一步不同&#xff0c;两者的逻辑是差不多的。一般都是&#xf…

C语言例4-6:格式字符d的使用例子

代码如下&#xff1a; //格式字符d的使用例子 #include<stdio.h> int main(void) {int num1123;long num2123456;printf("num1%d,num1%5d,num1%-5d,num1%2d\n",num1,num1,num1,num1);//以四种不同格式&#xff0c;输出int型数据num1的值printf("num2%ld,…

Elasticsearch 索引模板、生命周期策略、节点角色

简介 索引模板可以帮助简化创建和二次配置索引的过程&#xff0c;让我们更高效地管理索引的配置和映射。 索引生命周期策略是一项有意义的功能。它通常用于管理索引和分片的热&#xff08;hot&#xff09;、温&#xff08;warm&#xff09;和冷&#xff08;cold&#xff09;数…

【研发管理】产品经理知识体系-战略

导读&#xff1a;了解和掌握产品经理知识体系-战略是产品经理必修课。战略在产品创新管理框架中核心位置。本文概要梳理战略相关知识内容&#xff0c;仅供大家参考。 目录 1、战略定义 1.1 战略金字塔 1.2 战略的层级总表 1.3 战略跟战术的关系 1.4 愿景、使命和价值观​编…

ExoPlayer架构详解与源码分析(12)——Cache

系列文章目录 ExoPlayer架构详解与源码分析&#xff08;1&#xff09;——前言 ExoPlayer架构详解与源码分析&#xff08;2&#xff09;——Player ExoPlayer架构详解与源码分析&#xff08;3&#xff09;——Timeline ExoPlayer架构详解与源码分析&#xff08;4&#xff09;—…

yolov8直接调用zed相机实现三维测距(python)

yolov8直接调用zed相机实现三维测距&#xff08;python&#xff09; 1. 相关配置2. 相关代码3. 实验结果 相关链接 此项目直接调用zed相机实现三维测距&#xff0c;无需标定&#xff0c;相关内容如下&#xff1a; 1.yolov5直接调用zed相机实现三维测距&#xff08;python&#…

2024年哈尔滨工业大学材料科学与工程学院硕士研究生招生复试名单

2024年哈尔滨工业大学材料科学与工程学院硕士研究生招生复试名单 材料科学与工程学院2024年硕士研究生招生考试复试及录取工作方案 &#xff08;含深圳、威海校区&#xff0c;不含航天学院复合材料方向&#xff09; 复试录取名单数据分析: {51412, 50222, 50242, 61121, 50251…

蓝桥杯单片机快速开发笔记——利用定时器计数器设置定时器

一、基本原理 参考本栏http://t.csdnimg.cn/iPHN0 二、具体步骤 三、主要事项 如果使用中断功能记得打开总中断EA 四、示例代码 void Timer0_Isr(void) interrupt 1 { }void Timer0_Init(void) //10毫秒12.000MHz {AUXR & 0x7F; //定时器时钟12T模式TMOD & 0xF0;…

创建linux虚拟机系统:(安装Ubuntu镜像文件,包含语言设置、中文输入法、时间设置)

我下载的是清华大写开源软件镜像站中的ubuntu-20.04.6-desktop-amd64.iso这个镜像文件&#xff0c; 这个文件我下载完成之后没有解压&#xff0c;直接在创建虚拟机的时候选择的压缩包。 地址为&#xff1a;Index of /ubuntu-releases/20.04/ | 清华大学开源软件镜像站 | Tsin…

FastAPI+React全栈开发05 React前端框架概述

Chapter01 Web Development and the FARM Stack 05 The frontend React FastAPIReact全栈开发05 React前端框架概述 Let’s start with a bit of context here. Perhaps the changes in the world of the web are most visible when we talk about the frontend, the part o…

政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

政安晨的个人主页&#xff1a;政安晨 欢迎 &#x1f44d;点赞✍评论⭐收藏 收录专栏: TensorFlow与Keras实战演绎机器学习 希望政安晨的博客能够对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff01; 介绍 通过 Keras&#xff0c;您可以编写自定…

【解析几何】 【多源路径】 【贪心】1520 最多的不重叠子字符串

作者推荐 视频算法专题 本身涉及知识点 解析几何 图论 多源路径 贪心 LeetCode1520. 最多的不重叠子字符串 给你一个只包含小写字母的字符串 s &#xff0c;你需要找到 s 中最多数目的非空子字符串&#xff0c;满足如下条件&#xff1a; 这些字符串之间互不重叠&#xff0…

LeetCode 面试经典150题 392.判断子序列

题目&#xff1a; 给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些&#xff08;也可以不删除&#xff09;字符而不改变剩余字符相对位置形成的新字符串。&#xff08;例如&#xff0c;"ace"是"abcde"…

量子计算新“尺度”:用经典计算机评估复杂量子系统!

未来的量子计算机有望在计算机科学、医疗、商业、化学、物理学等多个领域解决难题&#xff0c;从而超越传统计算机。然而&#xff0c;目前的量子计算机仍存在局限&#xff0c;主要是由于它们固有的错误率。为此&#xff0c;研究者正致力于降低这些错误率。 一种研究量子计算机误…