混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算

1.背景介绍

在训练的模型的时候,需要评价模型的好坏,就涉及到混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算。

2、混淆矩阵的概念

所谓的混淆矩阵如下表所示:

TP:真正类,真的正例被预测为正例

FN:假负类,样本为正例,被预测为负类

FP:假正类 ,原本实际为负,但是被预测为正例

TN:真负类,真的负样本被预测为负类。

从混淆矩阵当中,可以得到更高级的分类指标:Accuracy(准确率),Precision(查准率),Recall(查全率),Specificity(特异性),Sensitivity(灵敏度)。

3. 常用的分类指标

3.1 Accuracy(准确率)

不管是哪个类别,只要预测正确,其数量都放在分子上,而分母是全部数据量。常用于表示模型的精度,当数据类别不平衡时,不能用于模型的评价。

Accuracy=\frac{TP+TN}{TP+FN+FN+TN}

3.2 Precision(查准率)

即所有预测为正的样本中,预测正确的样本的所占的比重。

Precision = \frac{TP}{TP+FP}

3.3  Recall(查全率)

真实的为正的样本,被正确检测出来的比重。

Recall=\frac{TP}{TP+FN}

3.4 Specificity(特异性)

特异性指标,也称 负正类率(False Positive Rate, FPR),计算的是模型错识别为正类的负类样本占所有负类样本的比例,一般越低越好。

FPR = \frac{FP}{TN+FP}

3.5 DSC(Dice coefficient)

Dice系数,是一种相似性度量,度量二进制图像分割的准确性。

如图所示红色的框的区域时Groudtruth,而蓝色的框为预测值Prediction。

DSC=\frac{2\left | G\sqcap P \right |}{\left | p \right |+\left | G \right |}

3.6 IoU(交并比)

IoU=\frac{p\sqcap G}{p\bigsqcup G}

3.7 Sensitivity(灵敏度)

反应的时预测正确的区域在Groundtruth中所占的比重。

SEN=\frac{\left | p \left | \sqcap \right |g\right | }{\left | G \right | }

4. 计算程序

ConfusionMatrix 这个类可以直接计算出混淆矩阵

from collections import defaultdict, deque
import datetime
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import errno
import os


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{value:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # 寻找GT中为目标的像素索引
            k = (a >= 0) & (a < n)
            # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        # 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        acc_global = torch.diag(h).sum() / h.sum()
        # 计算每个类别的准确率
        acc = torch.diag(h) / h.sum(1)
        # 计算每个类别预测与真实目标的iou
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                iu.mean().item() * 100)


class DiceCoefficient(object):
    def __init__(self, num_classes: int = 2, ignore_index: int = -100):
        self.cumulative_dice = None
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.count = None

    def update(self, pred, target):
        if self.cumulative_dice is None:
            self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        if self.count is None:
            self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        # compute the Dice score, ignoring background
        pred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()
        dice_target = build_target(target, self.num_classes, self.ignore_index)
        self.cumulative_dice += multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_index=self.ignore_index)
        self.count += 1

    @property
    def value(self):
        if self.count == 0:
            return 0
        else:
            return self.cumulative_dice / self.count

    def reset(self):
        if self.cumulative_dice is not None:
            self.cumulative_dice.zero_()

        if self.count is not None:
            self.count.zeros_()

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.cumulative_dice)
        torch.distributed.all_reduce(self.count)

分类指标的计算

import torch

# SR : Segmentation Result
# GT : Ground Truth

def get_accuracy(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)
    corr = torch.sum(SR==GT)
    tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
    acc = float(corr)/float(tensor_size)

    return acc

def get_sensitivity(SR,GT,threshold=0.5):
    # Sensitivity == Recall
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FN : False Negative
    TP = ((SR==1)+(GT==1))==2
    FN = ((SR==0)+(GT==1))==2

    SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6)     
    
    return SE

def get_specificity(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TN : True Negative
    # FP : False Positive
    TN = ((SR==0)+(GT==0))==2
    FP = ((SR==1)+(GT==0))==2

    SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6)
    
    return SP

def get_precision(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FP : False Positive
    TP = ((SR==1)+(GT==1))==2
    FP = ((SR==1)+(GT==0))==2

    PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6)

    return PC

def get_F1(SR,GT,threshold=0.5):
    # Sensitivity == Recall
    SE = get_sensitivity(SR,GT,threshold=threshold)
    PC = get_precision(SR,GT,threshold=threshold)

    F1 = 2*SE*PC/(SE+PC + 1e-6)

    return F1

def get_JS(SR,GT,threshold=0.5):
    # JS : Jaccard similarity
    SR = SR > threshold
    GT = GT == torch.max(GT)
    
    Inter = torch.sum((SR+GT)==2)
    Union = torch.sum((SR+GT)>=1)
    
    JS = float(Inter)/(float(Union) + 1e-6)
    
    return JS

def get_DC(SR,GT,threshold=0.5):
    # DC : Dice Coefficient
    SR = SR > threshold
    GT = GT == torch.max(GT)

    Inter = torch.sum((SR+GT)==2)
    DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)

    return DC

参考文献:

混淆矩阵的概念-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/352668.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

09. Springboot集成sse服务端推流

目录 1、前言 2、什么是SSE 2.1、技术原理 2.2、SSE和WebSocket 2.2.1、SSE (Server-Sent Events) 2.2.2、WebSocket 2.2.3、选择 SSE 还是 WebSocket&#xff1f; 3、Springboot快速集成 3.1、添加依赖 3.2、创建SSE控制器 3.2.1、SSEmitter创建实例 3.2.2、SSEmi…

K8s-持久化(持久卷,卷申明,StorageClass,StatefulSet持久化)

POD 卷挂载 apiVersion: v1 kind: Pod metadata:name: random-number spec:containers:- image: alpinename: alpinecommand: ["/bin/sh","-c"]args: ["shuf -i 0-100 -n 1 >> /opt/number.out;"]volumeMounts:- mountPath: /optname: da…

Ubuntu findfont: Font family ‘SimHei‘ not found.

matplotlib中文乱码显示 当我们遇到这样奇怪的问题时, 结果往往很搞笑 尝试1不行 Stopping Jupyter Installing font-manager: sudo apt install font-manager Cleaning the matplotlib cache directory: rm ~/.cache/matplotlib -fr Restarting Jupyter. 尝试2 This work fo…

AI大模型开发架构设计(6)——AIGC时代,如何求职、转型与选择?

文章目录 AIGC时代&#xff0c;如何求职、转型与选择&#xff1f;1 新职场&#xff0c;普通人最值钱的能力是什么?2 新职场成长的3点建议第1点&#xff1a;目标感第2点&#xff1a;执行力第3点&#xff1a;高效生产力 3 新职场会产生哪些新岗位机会?如何借势?4 新职场普通人…

微信小程序(十七)自定义组件生命周期(根据状态栏自适配)

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.获取手机状态栏的高度 2.验证attached可以修改数据 3.动态绑定样式数值 源码&#xff1a; myNav.js Component({lifetimes:{//相当于vue的created,因为无法更新数据被打入冷宫created(){},//相当于vue的mount…

【运行Python爬虫脚本示例】

主要内容&#xff1a;Python中的两个库的使用。 1、requests库&#xff1a;访问和获取网页内容&#xff0c; 2、beautifulsoup4库&#xff1a;解析网页内容。 一 python 爬取数据 1 使用requests库发送GET请求&#xff0c;并使用text属性获取网页内容。 然后可以对获取的网页…

基于python flask茶叶网站数据大屏设计与实现,可以做期末课程设计或者毕业设计

基于Python的茶叶网站数据大屏设计与实现是一个适合期末课程设计或毕业设计的项目。该项目旨在利用Python技术和数据可视化方法&#xff0c;设计和开发一个针对茶叶行业的数据大屏&#xff0c;用于展示和分析茶叶网站的相关数据。 项目背景 随着互联网的快速发展&#xff0c;越…

一张图区分Spring Task的3种模式

是的&#xff0c;只有一张图&#xff1a; fixedDelay 模式cron 模式fixedRate 模式

[HUBUCTF 2022 新生赛]ezPython

打开是一个pyc文件&#xff0c;我用的是pycdc进行反编译 # Source Generated with Decompyle # File: ezPython.pyc (Python 3.7)from Crypto.Util.number import * import base64 import base58 password open(password.txt, r).read() tmp bytes_to_long(password.encode(u…

基于JavaWeb开发的家具定制购买网站【附源码】

基于JavaWeb开发的家具定制购买网站【附源码】 &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获取联系方式 承接各种定制系统 &#…

GLog开源库使用

Glog地址&#xff1a;https://github.com/google/glog 官方文档&#xff1a;http://google-glog.googlecode.com/svn/trunk/doc/glog.html 1.利用CMake进行编译&#xff0c;生成VS解决方案 &#xff08;1&#xff09;在glog-master文件夹内新建一个build文件夹&#xff0c;用…

docker-compose Install influxdb1+influxdb2+telegraf

influxd2前言 influxd2 是 InfluxDB 2.x 版本的后台进程,是一个开源的时序数据库平台,用于存储、查询和可视化时间序列数据。它提供了一个强大的查询语言和 API,可以快速而轻松地处理大量的高性能时序数据。 telegraf 是一个开源的代理程序,它可以收集、处理和传输各种不…

在Ubuntu上安装pycuda记录

1. 安装CUDA Toolkit 11.8 从MZ小师妹的摸索过程来看&#xff0c;其他版本的会有bug&#xff0c;12.0的版本太高&#xff0c;11.5的太低&#xff08;感谢小师妹让我少走弯路&#xff09; 参考网址&#xff1a;CUDA Toolkit 11.8 Downloads | NVIDIA Developer 在命令行输入命…

调用阿里通义千问大语言模型API-小白新手教程-python

阿里大语言模型通义千问API使用新手教程 最近需要用到大模型&#xff0c;了解到目前国产大模型中&#xff0c;阿里的通义千问有比较详细的SDK文档可进行二次开发,目前通义千问的API文档其实是可以进行精简然后学习的,也就是说&#xff0c;是可以通过简单的API调用在自己网页或…

Vue<圆形旋转菜单栏效果>

效果图&#xff1a; 大家不一定非要制成菜单栏&#xff0c;可以看下人家的华丽效果&#x1f61d;&#xff0c;参考地址 https://travelshift.com/ 大佬写的效果可比我的强多了&#xff0c;但是无从下手&#xff0c;所以就自己琢磨怎么写了&#xff0c;只能说效果勉强差不多 可…

“steam教学理念”scratch+数学 ——时钟案例

一、时钟概念 它通常由一个圆形表盘组成&#xff0c;表盘上有12个数字&#xff0c;分别是1到12。这些数字代表了小时。在表盘上&#xff0c;还有三根指针&#xff0c;一根较短的指针叫做时针&#xff0c;另一根较长的指针叫做分针&#xff0c;而秒针通常为红色&#xff0c;且指…

LabVIEW电缆检修系统

在电力系统中&#xff0c;合理选择电缆检修策略是保障电网稳定运行的关键。现有的电缆检修策略往往忽视了电缆的技术和经济双重指标&#xff0c;导致检修效率低下和维护成本高昂。为此&#xff0c;开发了一种基于风险评估模型和全寿命周期成本&#xff08;LCC&#xff09;的电缆…

java金额数字转中文

java金额数字转中文 运行结果&#xff1a; 会进行金额的四舍五入。 工具类源代码&#xff1a; /*** 金额数字转为中文*/ public class NumberToCN {/*** 汉语中数字大写*/private static final String[] CN_UPPER_NUMBER {"零", "壹", "贰",…

Springboot+Netty搭建基于TCP协议的服务端

文章目录 概要pom依赖Netty的server服务端类Netty通道初始化I/O数据读写处理测试发送消息 并 接收服务端回复异步启动Netty运行截图 概要 Netty是业界最流行的nio框架之一&#xff0c;它具有功能强大、性能优异、可定制性和可扩展性的优点 Netty的优点&#xff1a; 1.API使用简…

高中数学常识

一、大小关系 |x| > |sinx| 理由&#xff1a; 很明显&#xff0c;在圆内&#xff0c;弧长x>垂线sinx 3x、2x 、 1 2 \frac{1}{2} 21​x 理由&#xff1a; log 1 2 _\frac{1}{2} 21​​x、log 2 _2 2​x、 log 3 _3 3​x 二、(xy)? 的求法 利用二项式定理 三、平…