模型性能的主要指标

主要参数

ROC 曲线和混淆矩阵都是用来评估分类模型性能的工具

ROC曲线(Receiver Operating Characteristic curve)

  • ROC曲线描述了当阈值变化时,真正类率(True Positive Rate, TPR)和假正类率(False Positive Rate, FPR)之间的关系。
  • TPR(也称为灵敏度或召回率)是正实例中正确预测为正的比例。
  • FPR是负实例中错误预测为正的比例。
  • ROC曲线下的面积(AUC-ROC)可以量化分类器的整体性能,值介于0和1之间。一个完美的分类器的AUC-ROC为1,而随机猜测的AUC-ROC为0.5。
  • ROC曲线主要用于评估二元分类问题,但也可以扩展到多分类场景。

混淆矩阵(Confusion Matrix)

  • 混淆矩阵是一个表格,用于显示实际类别与预测类别之间的对应关系。
  • 对于二元分类问题,混淆矩阵有四个组成部分:真正(True Positive, TP),假正(False Positive, FP),真负(True Negative, TN)和假负(False Negative, FN)。
    • TP: 正确预测的正样本数。
    • FP: 错误预测的正样本数。
    • TN: 正确预测的负样本数。
    • FN: 错误预测的负样本数。
  • 除了基本的四个值,还可以计算其他的性能指标,如精确度、召回率、F1分数等。
  • 对于多分类问题,混淆矩阵将扩展为一个更大的方阵,其中每一行代表一个实际类别,每一列代表一个预测类别。

ROC曲线特别适用于比较不同模型或算法的性能,而混淆矩阵可以为我们提供更详细的分类结果信息。

代码实现

ROC曲线

ROC曲线是用于评估二元分类模型效果的工具。Scikit-learn提供了方便的工具来生成ROC曲线。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import torch

# 假设您已经得到了模型的预测概率和真实标签
# predictions: [N, 2]  N是样本数,2是类别数
# labels: [N]
predictions = torch.rand((100, 2))
labels = torch.randint(0, 2, (100,))

# 获取正类的预测概率
probs = predictions[:, 1]

# 计算ROC
fpr, tpr, thresholds = roc_curve(labels, probs)

roc_auc = auc(fpr, tpr)

# 绘图
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.show()

 

混淆矩阵

混淆矩阵是一个简单的工具,它可以显示真实标签和预测标签之间的对应关系。

from sklearn.metrics import confusion_matrix
import seaborn as sns

# 根据概率获取预测的标签
_, predicted_labels = predictions.max(1)

cm = confusion_matrix(labels, predicted_labels)

# 绘图
plt.figure(figsize=(5,5))
sns.heatmap(cm, annot=True, fmt=".0f", linewidths=.5, square = True, cmap = 'Blues_r');
plt.ylabel('Actual label');
plt.xlabel('Predicted label');
plt.title('Confusion Matrix');
plt.show()

 

注意:需要先安装matplotlib, seabornscikit-learn。可以使用pip来安装这些库。

 补充:其他性能指标

了ROC曲线和混淆矩阵外,分类模型常用的性能指标还有准确率、精确度、召回率、F1分数等。这些指标都可以基于混淆矩阵的四个基本值(TP, FP, TN, FN)来计算。以下是这些指标的定义和如何使用Python代码实现它们:

准确率(Accuracy):所有预测正确的样本与总样本的比例。

def accuracy(TP, FP, FN, TN):
    return (TP + TN) / (TP + FP + FN + TN)

精确度(Precision):正确预测为正的样本与所有预测为正的样本的比例。 

def precision(TP, FP):
    return TP / (TP + FP)

 召回率(Recall)或灵敏度(Sensitivity):正确预测为正的样本与所有实际为正的样本的比例。

def recall(TP, FN):
    return TP / (TP + FN)

 F1分数(F1 Score):精确度和召回率的调和平均值。

def f1_score(precision_value, recall_value):
    return 2 * (precision_value * recall_value) / (precision_value + recall_value)

 注意,上述公式和代码片段是为二元分类设计的。多分类问题需要稍作调整或使用micro、macro等不同的计算方式。

此外,我们可以直接调用scikit-learn库为我们提供了方便的函数来计算这些指标:

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = [1, 0, 1, 1, 0, 1]
y_pred = [1, 0, 1, 0, 0, 1]

print("Accuracy:", accuracy_score(y_true, y_pred))
print("Precision:", precision_score(y_true, y_pred))
print("Recall:", recall_score(y_true, y_pred))
print("F1 Score:", f1_score(y_true, y_pred))

“IOU”是Intersection over Union的缩写,它是一种测量对象检测算法预测边界框与实际边界框之间重叠程度的指标。IOU常用于计算目标检测和语义分割任务的准确性。

IOU定义为两个边界框交集的面积与它们并集的面积之比:

“mAP曲线”是基于不同IOU阈值的平均精确度。如果计算IoU并将其用于评估,以下是一个简单的Python代码示例,演示如何计算两个矩形的IoU:

 

def bb_intersection_over_union(boxA, boxB):
    # 确定交集矩形的(x, y)坐标
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    # 计算交集矩形的面积
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    # 计算预测矩形和实际矩形的面积
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    # 通过取交集面积并除以预测和实际矩形的面积之和(减去交集面积)来计算并集
    iou = interArea / float(boxAArea + boxBArea - interArea)

    # 返回并集值
    return iou

# 使用示例:
boxA = [10, 10, 50, 50]  # 格式: [x1, y1, x2, y2]
boxB = [20, 20, 60, 60]
print(bb_intersection_over_union(boxA, boxB)) # 应该打印一个0到1之间的值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/71621.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android Studio跳过Haxm打开模拟器

由于公司权限限制无法安装Haxm,这个时候我们可以试试Arm相关的镜像去跳过Haxm运行模拟器。解决方案:安装API27以下的Arm Image. #ifdef __x86_64__if (sarch "arm64" && apiLevel >28) {APANIC("Avds CPU Architecture %s i…

NPM与外部服务的集成(上)

目录 1、关于访问令牌 1.1 关于传统令牌 1.2 关于粒度访问令牌 2、创建和查看访问令牌 2.1 创建访问令牌 在网站上创建传统令牌 在网站上创建粒度访问令牌 使用CLI创建令牌 CIDR限制令牌错误 查看访问令牌 在网站上查看令牌 在CLI上查看令牌 令牌属性 1、关于访问令…

mysql数据库第十二课------mysql语句的拔高2------飞高高

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…

【C++】开源:CGAL计算几何库配置使用

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍CGAL计算几何库配置使用。 无专精则不能成,无涉猎则不能通。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关注一下,…

【java】mybatis-plus代码生成

正常的代码生成这里就不介绍了。旨在记录实现如下功能: 分布式微服务环境下,生成的entity、dto、vo、feignClient等等api模块,需要和mapper、service、controller等等分在不同的目录生成。 为什么会出现这个需求? mybatis-plus&am…

【计算机视觉|生成对抗】用深度卷积生成对抗网络进行无监督表示学习(DCGAN)

本系列博文为深度学习/计算机视觉论文笔记,转载请注明出处 标题:Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks 链接:[1511.06434] Unsupervised Representation Learning with Deep Conv…

微信小程序中键盘弹起输入框自动跳到键盘上方处理

效果展示 键盘未弹起时 键盘弹起后: 实现方式 话就不多说了 我直接贴代码了 原理就是用你点击的输入框的底部 距离顶部的位置 减去屏幕高度除以2,然后设成负值,再将这个值给到最外层相对定位的盒子的top属性,这样就不会出现顶…

服务器安装JDK

三种方法 方法一: 方法二: 首先登录到Oracle官网下载JDK JDK上传到服务器中,记住文件上传的位置是在哪里(我放的位置在/www/java),然后看下面指示进行安装 方法三: 首先登录到Oracle官网下载…

修改IDEA的idea.vmoptions参数导致IDEA无法打开(ReservedCodeCacheSize)

事发原因 Maven导依赖的时候OOM,因此怀疑是内存太小,尝试修改idea.vmoptions的参数,然后发现IDEA重启后打不开了,卸载重装后也无法打开。。。 实际上如果导包爆出OOM的话应该调整下图参数,不过这都是后话了 解决思路…

王道机组难题分析

第四章 指令系统 大端方式:就是高地址存放高位, LSB的意思是:全称为Least Significant Bit,在二进制数中意为最低有效位 MSB的意思是:全称为Most Significant Bit,在二进制数中属于最高有效位 操作数可以理…

HCIP学习--BGP2

目录 前置内容 BGP宣告问题 BGP自动汇总问题 BGP 的认证 BGP的聚合(汇总) 标准的BGP聚合配置 非标准的BGP聚合 路由传递干涉策略 抑制列表 Route-map 分发列表 前缀列表 BGP在MA网络中下一跳问题-ICMP重定向 查看与某个邻居收发的路由 配置 有条件打破IBGP水平…

腾讯云轻量应用服务器镜像应用模板清单大全

腾讯云轻量应用服务器支持多种应用模板镜像,Windows和Linux镜像模板都有,如:宝塔Linux面板腾讯云专享版、WordPress、WooCommerce、LAMP、Node.js、Docker CE、K3s、宝塔Windows面板和ASP.NET等应用模板镜像,腾讯云服务器网分享腾…

04 - 分离头指针情况、理解HEAD和branch

查看所有文章链接:(更新中)GIT常用场景- 目录 文章目录 1. 分离头指针2. HEAD和branch2.1 branch的一些操作2.2 HEAD 1. 分离头指针 分离头指针detached HEAD是一种HEAD指针指向了某一个具体的 commit id,而不是分支的情况。 切换…

HTTP请求性能分析 - 简单

使用随手可得的工具,尽量少的前置要求,来完成任务。 0. 目录 1. 前言2. 分析工具2.1 基于Chrome DevTools 的Timing2.1.1 关于Network标签页下的Timing部分2.1.2 一些注意项 2.2 基于Curl 命令 3. 剩下的工作 1. 前言 对于业务开发选手而言,…

Vision Transformer模型入门

Vision Transformer模型入门 一、Vision Transformer 模型1,Embedding 层结构详解2,Transformer Encoder 详解3,MLP Head 详解 二、ViT-B/16 网络结构三、Hybrid 模型详解四、ViT 模型搭建参数 一、Vision Transformer 模型 总体三个模块&am…

无涯教程-Perl - getprotobynumber函数

描述 此函数在标量context中将协议NUMBER转换为其对应的名称,在列表context中将其名称和相关信息转换为:($name,$aliases,$protocol_number)。 语法 以下是此函数的简单语法- getprotobynumber NUMBER返回值 此函数针对错误返回undef,否则返回标量context中的协议编号,并在…

HBase API

我们之后的实际开发中不可能在服务器那边直接使用shell命令一直敲的&#xff0c;一般都是通过API进行操作的。 环境准备 新建Maven项目&#xff0c;导入Maven依赖 <dependencies><dependency><groupId>org.apache.hbase</groupId><artifactId>…

Java中创建线程三种方式

继承Thread类创建线程实现Runnable接口创建线程使用Callable和Future创建线程 继承Thread类 /*** 使用集成Thread的方式实现多线程*/ public class Match1 {public static void main(String[] args) {Runner liuxiang new Runner();//创建一个新的线程liuxiang.setName(&quo…

当执行汇编指令MOV [0001H] 01H时,CPU都做了什么?

今天和几位单位大佬聊天时&#xff0c;讨论到一个非常有趣的问题-当程序执行MOV [0001H], 01H计算机实际上都做了哪些工作&#xff1f;乍一看这个问题平平无奇&#xff0c;CPU只是把立即数01H放在了地址为0001的内存里&#xff0c;但仔细想想这个问题远没有那么简单&#xff0c…

matlab解常微分方程常用数值解法1:前向欧拉法和改进的欧拉法

总结和记录一下matlab求解常微分方程常用的数值解法&#xff0c;本文先从欧拉法和改进的欧拉法讲起。 d x d t f ( x , t ) , x ( t 0 ) x 0 \frac{d x}{d t}f(x, t), \quad x\left(t_{0}\right)x_{0} dtdx​f(x,t),x(t0​)x0​ 1. 前向欧拉法 前向欧拉法使用了泰勒展开的第…