【深度学习】多分类任务评估指标sklearn和torchmetrics对比

【深度学习】多分类任务评估指标sklearn和torchmetrics对比

  • 说明
  • sklearn代码
  • torchmetrics代码
  • 两个MultiClassReport类的对比分析
    • 1. 代码结构与实现方式
    • 2. 数据处理与内存使用
    • 3. 性能与效率
  • 二分类任务评估指标
    • 1. 准确率(Accuracy)
    • 2. 精确率(Precision)
    • 3. 召回率(Recall)
    • 4. F1值(F1-score)
  • 多分类评估指标
    • 1. 混淆矩阵(Confusion Matrix)
    • 2. 准确率(Accuracy)
    • 3. 精确率(Precision)
    • 4. 召回率(Recall)
    • 5. F1值(宏平均)

说明

sklearn和torchmetrics两个metric代码跑模型的输出结果一致,对比他们的区别。评估指标写在下面

sklearn代码

import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi - class classification task.
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name
        self.reset()

    def reset(self):
        """
        Resets all the metric state.
        """
        self.y_prob = []
        self.y_true = []

    def update(self, probs, labels):
        # 将Tensor转换为numpy数组并添加到相应列表中
        if isinstance(probs, torch.Tensor):
            if probs.requires_grad:
                probs = probs.detach()
            probs = probs.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            if labels.requires_grad:
                labels = labels.detach()
            labels = labels.cpu().numpy()
        self.y_prob.extend(probs)
        self.y_true.extend(labels)
        self.y_prob.extend(probs)
        self.y_true.extend(labels)

    def accumulate(self):
        accuracy = accuracy_score(self.y_true, np.argmax(self.y_prob, axis=1))
        f1 = f1_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        precision = precision_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        recall = recall_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

torchmetrics代码

from torchmetrics import Accuracy, F1Score, Precision, Recall
from model import polarity_classes, device

# 创建评估指标对象
accuracy_metric = Accuracy(task='multiclass', num_classes=polarity_classes).to(device)
f1_metric = F1Score(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
precision_metric = Precision(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
recall_metric = Recall(task='multiclass', num_classes=polarity_classes, average='macro').to(device)

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi-class classification task.
    average:micro、macro
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name

    def reset(self):
        """
        Resets all the metric state.
        """
        accuracy_metric.reset()
        f1_metric.reset()
        precision_metric.reset()
        recall_metric.reset()

    def update(self, probs, labels):
        accuracy_metric.update(probs, labels)
        f1_metric.update(probs, labels)
        precision_metric.update(probs, labels)
        recall_metric.update(probs, labels)

    def accumulate(self):
        accuracy = accuracy_metric.compute()
        f1 = f1_metric.compute()
        precision = precision_metric.compute()
        recall = recall_metric.compute()
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

两个MultiClassReport类的对比分析

1. 代码结构与实现方式

  • sklearn版本
    • 代码逻辑较为清晰直接。在update方法中,将输入的PyTorch张量转换为numpy数组,并存储到y_proby_true列表中。在accumulate方法中,直接使用sklearnaccuracy_scoref1_scoreprecision_scorerecall_score函数基于存储的列表数据计算评估指标。
  • torchmetrics版本
    • 利用torchmetrics库提供的专门的评估指标类(AccuracyF1ScorePrecisionRecall)。在update方法中,直接调用这些类的update方法来处理输入数据,内部有自己的状态管理机制。在accumulate方法中,通过调用相应类的compute方法获取评估指标值。
    • 这种方式与PyTorch的生态系统集成得更好,尤其是在基于PyTorch进行深度学习项目开发时,可以方便地在GPU上进行计算(如果deviceGPU),并且可以利用torchmetrics库的其他特性,如分布式训练支持等。

2. 数据处理与内存使用

  • sklearn版本
    • update方法中不断扩展y_proby_true列表来存储数据。如果处理大量数据,可能会占用较多内存,因为它需要将所有的预测概率和真实标签都保存在内存中。
    • 每次计算评估指标时,都需要对整个存储的数组进行操作,如np.argmax等,这在数据量较大时可能会有一定的计算开销。
  • torchmetrics版本
    • 虽然torchmetrics类内部也需要存储一定的状态信息,但它们的设计可能更高效地利用内存和处理数据更新。例如,它们可能会采用增量计算的方式,而不是像sklearn版本那样一次性处理所有数据。
    • 在处理大规模数据或长时间训练过程中,torchmetrics版本可能在内存管理和计算效率方面更有优势。

3. 性能与效率

  • sklearn版本
    • 在小规模数据和简单场景下,性能表现良好。但随着数据量的增加和模型复杂度的提高,由于数据转换和计算方式的原因,可能会出现性能瓶颈。
  • torchmetrics版本
    • 设计初衷就是为了在PyTorch深度学习环境中高效运行,特别是在利用GPU计算资源时,能够更高效地更新和计算评估指标,更适合大规模数据和复杂模型的评估场景。

二分类任务评估指标

TP(True Positive)是真正例,TN(True Negative)是真反例,FP(False Positive)是假正例,FN(False Negative)是假反例。

1. 准确率(Accuracy)

准确率是指在所有预测样本中,预测正确的样本所占的比例。它衡量的是模型整体预测正确的程度。
在这里插入图片描述

2. 精确率(Precision)

精确率是指在所有被预测为正类的样本中,真正为正类的样本所占的比例。
![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/0c6f404c85df462ab1c4831f84cbd94d.png

3. 召回率(Recall)

召回率是指在所有实际为正类的样本中,被模型正确预测为正类的样本所占的比例。
在这里插入图片描述

4. F1值(F1-score)

F1值是精确率和召回率的调和平均数,它综合考虑了精确率和召回率两个指标,能够更全面地评估模型的性能。
在这里插入图片描述

多分类评估指标

1. 混淆矩阵(Confusion Matrix)

它是一个方阵,用来展示分类模型在每个类别上的预测对错情况。行代表真实类别,列代表预测类别,某个位置的值就是实际是某类却被预测成另一类的样本数量,能直观呈现模型对各类别预测的混淆情况。
在这里插入图片描述

2. 准确率(Accuracy)

就是模型预测正确的样本数占总样本数的比例,反映整体预测正确程度。
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

3. 精确率(Precision)

  • 类别精确率:对于每个类别,是预测为该类且正确的样本数除以预测为该类的样本数,看预测某类时的准确程度。
    在这里插入图片描述

  • 宏平均精确率(Macro-average Precision):先算出每个类别的精确率,再求平均,平等看待每个类别。
    在这里插入图片描述

  • 微平均精确率(Micro-average Precision):将所有类别预测对的情况汇总除以预测的总数,从整体上看预测的精准情况,对类别不平衡不太敏感。
    在这里插入图片描述

4. 召回率(Recall)

  • 类别召回率:对于每个类别,是预测为该类且正确的样本数除以实际是该类的样本数,体现对该类样本的召回能力。
    在这里插入图片描述

  • 宏平均召回率(Macro- average Recall):先算出每个类别的召回率,再求平均,衡量对每个类别样本的召回水平。
    在这里插入图片描述

  • 微平均召回率(Micro-average Recall):从整体角度,用所有类别预测正确的样本总数除以实际各类别样本总数,综合评估召回情况。
    在这里插入图片描述

5. F1值(宏平均)

  • 类别F1值:是类别精确率和召回率的调和平均数,综合二者信息。
    在这里插入图片描述

  • 宏平均F1值(Macro-average F1):先计算每个类别的F1值,再平均,更全面地体现模型对各分类的整体性能。
    在这里插入图片描述

  • 微平均F1值(Micro-average F1):基于微平均精确率(Precision micro)和微平均召回率(Recall micro)来计算
    微平均
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/911264.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[CUDA] 设置sync模式cudaSetDeviceFlags

文章目录 1. 设置cuda synchronize的等待模式2 设置函数3. streamQuery方式实现stream sync等待逻辑Reference 1. 设置cuda synchronize的等待模式 参考资料:https://docs.nvidia.com/cuda/pdf/CUDA_Runtime_API.pdf cuda的 synchronize等待模式分为: Y…

jdk安装升级到jdk17

百度安全验证 有些项目编译不过 找不到类 ,实际有,需要升级jdk到17 https://blog.csdn.net/qq_44866828/article/details/130557027 sudo apt-get update sudo apt-get install openjdk-17-jdk 然后修改一下配置路径 也就是环境变量 11 改成17 重新…

cuda、pytorch-gpu安装踩坑!!!

前提:已经安装了acanoda cuda11.6下载 直接搜索cuda11.6 acanoda操作 python版本3.9 conda create -n pytorch python3.9conda activate pytorch安装Pytorch-gpu版本等包 要使用pip安装,cu116cuda11.6版本 pip install torch1.13.1cu116 torchvi…

H.265流媒体播放器EasyPlayer.js网页web无插件播放器:如何优化加载速度

在当今的网络环境中,用户对于视频播放体验的要求越来越高,尤其是对于视频加载速度的期待。EasyPlayer.js网页web无插件播放器作为一款专为现代Web环境设计的流媒体播放器,它在优化加载速度方面采取了多种措施,以确保用户能够享受到…

C语言 | Leetcode C语言题解之第542题01矩阵

题目: 题解: /*** Return an array of arrays of size *returnSize.* The sizes of the arrays are returned as *returnColumnSizes array.* Note: Both returned array and *columnSizes array must be malloced, assume caller calls free().*/ type…

Transformer究竟是什么?预训练又指什么?BERT

目录 Transformer究竟是什么? 预训练又指什么? BERT的影响力 Transformer究竟是什么? Transformer是一种基于自注意力机制(Self-Attention Mechanism)的神经网络架构,它最初是为解决机器翻译等序列到序列(Seq2Seq)任务而设计的。与传统的循环神经网络(RNN)或卷…

【春秋云镜】CVE-2023-23752

目录 CVE-2023-23752漏洞细节漏洞利用示例修复建议 春秋云镜:解法一:解法二: CVE-2023-23752 是一个影响 Joomla CMS 的未授权路径遍历漏洞。该漏洞出现在 Joomla 4.0.0 至 4.2.7 版本中,允许未经认证的远程攻击者通过特定 API 端…

51单片机教程(七)- 蜂鸣器

1 项目分析 利用P2.3引脚输出电平变化,控制蜂鸣器的鸣叫。 2 技术准备 1 蜂鸣器介绍 有绿色电路板的一种是无源蜂鸣器,没有电路板而用黑胶封闭的一种是有源蜂鸣器。 有源蜂鸣器和无源蜂鸣器 这里的“源”不是指电源。而是指震荡源。也就是说有源蜂鸣…

十六 MyBatis使用PageHelper

十六、MyBatis使用PageHelper 16.1 limit分页 mysql的limit后面两个数字: 第一个数字:startIndex(起始下标。下标从0开始。)第二个数字:pageSize(每页显示的记录条数) 假设已知页码pageNum&…

汽车和飞机研制过程中“骡车”和“铁鸟”

在汽车和飞机的研制过程中,“骡车”和“铁鸟”都扮演着至关重要的角色。 “骡车”在汽车研制中,是一种处于原型车和量产车之间的过渡阶段产物。它通常由不同的零部件组合而成,就像骡子是马和驴的杂交后代一样,取各家之长。“骡车…

MySQL存储目录与配置文件(ubunto下)

mysql的配置文件: 在这个目录下,直接cd /etc/mysql/mysql.conf.d mysql的储存目录: /var/lib/mysql Ubuntu版本号:

RibbitMQ-安装

本文主要介绍RibbitMQ的安装 RabbitMQ依赖于Erlang,因此首先需要安装Erlang环境。分别下载erlang-26.2.5-1.el7.x86_64.rpm、rabbitmq-server-4.0.3-1.el8.noarch.rpm 官网地址:https://www.rabbitmq.com/ 官网文档:https://www.rabbitmq.c…

【Linux】解锁操作系统潜能,高效线程管理的实战技巧

目录 1. 线程的概念2. 线程的理解3. 地址空间和页表4. 线程的控制4.1. POSIX线程库4.2 线程创建 — pthread_create4.3. 获取线程ID — pthread_self4.4. 线程终止4.5. 线程等待 — pthread_join4.6. 线程分离 — pthread_detach 5. 线程的特点5.1. 优点5.2. 缺点5.3. 线程异常…

WPF+MVVM案例实战(二十二)- 制作一个侧边弹窗栏(CD类)

文章目录 1、案例效果1、侧边栏分类2、CD类侧边弹窗实现1、样式代码实现2、功能代码实现3 运行效果4、源代码获取1、案例效果 1、侧边栏分类 A类 :左侧弹出侧边栏B类 :右侧弹出侧边栏C类 :顶部弹出侧边栏D类 :底部弹出侧边栏2、CD类侧边弹窗实现 1、样式代码实现 在原有的…

如何对LabVIEW软件进行性能评估?

对LabVIEW软件进行性能评估,可以从以下几个方面着手,通过定量与定性分析,全面了解软件在实际应用中的表现。这些评估方法适用于确保LabVIEW程序的运行效率、稳定性和可维护性。 一、响应时间和执行效率 时间戳测量:使用LabVIEW的时…

stm32使用串口DMA实现数据的收发

前言 DMA的作用就是帮助CPU来传输数据,从而使CPU去完成更重要的任务,不浪费CPU的时间。 一、配置stm32cubeMX 这两个全添加上。参数配置一般默认即可 代码部分 只需要把上期文章里的HAL_UART_Transmit_IT(&huart2,DATE,2); 全都改为HAL_UART_Tra…

论文1—《基于卷积神经网络的手术机器人控制系统设计》文献阅读分析报告

论文报告:基于卷积神经网络的手术机器人控制系统设计 摘要 本研究针对传统手术机器人控制系统精准度不足的问题,提出了一种基于卷积神经网络的手术机器人控制系统设计。研究设计了控制系统的总体结构,并选用PCI插槽上直接内插CAN适配卡作为上…

「C/C++」C/C++ 之 变量作用域详解

✨博客主页何曾参静谧的博客📌文章专栏「C/C」C/C程序设计📚全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasoli…

JSP ft06 问题几个求解思路整理

刷到这篇文章使用Q-learning去求接JSP ft06 问题用基本Q-learning解决作业车间调度问题(JSP),以FT06案例为例_q-learning算法在车间调度-CSDN博客 本着贼不走空的原则打算全部copy到本地试下,文章作者使用的tf06.txt在这里获取 https://web.cecs.pdx.e…

Uniapp安装Pinia并持久化(Vue3)

安装pinia 在uni-app的Vue3版本中,Pinia已被内置,无需额外安装即可直接使用(Vue2版本则内置了Vuex)。 HBuilder X项目:直接使用,无需安装。CLI项目:需手动安装,执行yarn add pinia…