sklearn中多分类和多标签分类评估方法总结

一、任务区分

  1. 多分类分类任务:在多分类任务中,每个样本只能被分配到一个类别中。换句话说,每个样本只有一个正确的标签。例如,将图像分为不同的物体类别,如猫、狗、汽车等。

  2. 多标签分类任务:在多标签分类任务中,每个样本可以被分配到一个或多个类别中。换句话说,每个样本可以有多个正确的标签。例如,在图像标注任务中,一张图像可能同时包含猫和狗,因此它可以同时被分配到 "猫" 和 "狗" 这两个标签。

二、sklearn评估方式

 1、多分类(multiclass)任务

多分类任务的标签有两种,一种是原始标签,例如[0,1,2],另一种是独热编码的形式,[[1,0,0],[0,1,0],[0,0,1]]

经过模型分类之后的结果一般是各类的预测分数

(1)准确率(Accuracy):是分类正确的样本数与总样本数之比,是最简单的评估方法,但在类别不平衡的情况下可能会失效。

(2)混淆矩阵(Confusion Matrix):是一个 N×N 的矩阵(N 为类别数量),将真实类别与预测类别的对应关系表示出来。基于混淆矩阵可以计算精确率、召回率、F1 分数等指标。

(3)精确率(Precision)召回率(Recall):精确率表示被分类器正确分类的正样本数量与分类器预测为正样本的样本数量之比;召回率表示被分类器正确分类的正样本数量与数据集中所有正样本数量之比。

(4)F1 分数:精确率和召回率的调和平均数,综合考虑了分类器的准确性和完整性。

(5)ROC 曲线和AUC(Area Under the Curve):对于二分类任务,可以绘制ROC曲线,以真正例率(True Positive Rate)作为纵轴,假正例率(False Positive Rate)作为横轴。AUC表示ROC曲线下的面积,是一个评估分类器性能的常用指标。对于多分类任务,通常使用微平均(micro-average)或宏平均(macro-average)来计算AUC。

- 引入模块,并自己定义一下模型输出

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import numpy as np
import matplotlib.pyplot as plt
import torch

# 示例真实标签和预测结果
true_labels = np.array([0, 1, 2, 1, 0, 2, 2, 1, 0, 1])
print("true label",true_labels)
# 生成随机数据作为概率值,实际应用中需要替换为模型的预测概率值
model_output = torch.randn(len(true_labels), 3)
print("model output",model_output)
# 获得最大类别的index
_, predicted_labels = torch.max(model_output, 1)
print("predicted label",predicted_labels)

 示例数据如下:

- 进行模型评估

注意,计算roc_auc时需要将输出概率归一化,否则会报错

ValueError: Target scores need to be probabilities for multiclass roc_auc, i.e. they should sum up to 1.0 over classes

 准确率等的计算用的是模型输出后最大类别的index,而计算roc_auc直接用模型输出的概率,但需要归一化。

# 准确率
accuracy = accuracy_score(true_labels, predicted_labels)
print("Accuracy:", accuracy)
# 混淆矩阵
conf_matrix = confusion_matrix(true_labels, predicted_labels)
print("Confusion Matrix:\n", conf_matrix)
# 分类报告
class_report = classification_report(true_labels, predicted_labels)
print("Classification Report:\n", class_report)
# 精确率
precision = precision_score(true_labels, predicted_labels, average='macro')
print("Precision:", precision)
# 召回率
recall = recall_score(true_labels, predicted_labels, average='macro')
print("Recall:", recall)
# F1 分数
f1 = f1_score(true_labels, predicted_labels, average='macro')
print("F1 Score:", f1)
# ROC AUC
# 计算ROC需要将模型输出概率归一化
prob_new = torch.nn.functional.softmax(model_output, dim=1)
print(prob_new)
roc_auc = roc_auc_score(true_labels, prob_new, average='macro', multi_class='ovo')
print("ROC AUC Score:", roc_auc)

结果:

 

- 独热编码

但如果是对原始的标签数据进行了独热编码,那么在进行准确率等的计算的时候,需要将输出也转化为与独热编码类似的形式,然后再使用sklearn的函数进行计算

from sklearn.preprocessing import label_binarize
# 进行独热编码
true_one_hot = label_binarize(true_labels, classes=np.arange(3))
# 获取每行最大值的索引
max_indices = torch.argmax(model_output, dim=1)
# 创建一个与模型输出相同形状的零张量
predicted_labels = torch.zeros_like(model_output)
# 将每行最大值的位置设为1
predicted_labels[torch.arange(len(max_indices)), max_indices] = 1
print("predicted labels",predicted_labels)

accuracy = accuracy_score(true_one_hot, predicted_labels)
print("Accuracy:", accuracy)
roc_auc = roc_auc_score(true_one_hot, model_output, average='macro')
print("ROC AUC Score:", roc_auc)

结果如下:

总之,无论是采用原始标签的形式,还是独热编码的形式,在计算accuracy,recall,precision,F1-score的时候,都需要将模型输出转化为0,1且与真实标签维度一致的格式,而在计算roc的时候,若是独热编码的真实标签,则可以直接用模型输出,但如果不是,就需要归一化概率。

2、多标签(multilabel)分类任务

对于多标签分类,初始的真实标签要用到独热编码

 针对模型输出的概率分数,需要设定一个阈值,大于阈值的标记为1,低于阈值的标记为0,如下代码所示:

import torch
# 示例模型输出
model_output = torch.tensor([[0.8, 0.3, 0.9],
                             [0.2, 0.7, 0.4],
                             [0.9, 0.1, 0.3]])
# 设置阈值
threshold = 0.5
# 将概率分数转换为0-1结果
predicted_labels = (model_output > threshold).float()
print(predicted_labels)

然后直接使用sklearn的函数进行评估

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss, jaccard_score, coverage_error, average_precision_score, roc_auc_score
import numpy as np

# 示例标签和预测结果
true_labels = np.array([[1, 0, 1], [0, 1, 1], [1, 1, 0]])
predicted_labels = np.array([[1, 0, 1], [0, 1, 0], [1, 0, 0]])
# 准确率
accuracy = accuracy_score(true_labels, predicted_labels)
print("Accuracy:", accuracy)
# 精确率
precision = precision_score(true_labels, predicted_labels, average='micro')
print("Precision:", precision)
# 召回率
recall = recall_score(true_labels, predicted_labels, average='micro')
print("Recall:", recall)
# F1 分数
f1 = f1_score(true_labels, predicted_labels, average='micro')
print("F1 Score:", f1)
# 平均准确率
average_precision = average_precision_score(true_labels, predicted_labels, average='micro')
print("Average Precision:", average_precision)
# ROC AUC
roc_auc = roc_auc_score(true_labels, predicted_labels, average='micro')
print("ROC AUC Score:", roc_auc)

得到的结果如下:

详细对于多标签分类的指标解释可以参考下面的文章:

sklearn中多标签分类场景下的常见的模型评估指标_51CTO博客_sklearn模型评估icon-default.png?t=N7T8https://blog.51cto.com/liguodong/4290183总的来说,要根据自己的任务和目标来制定合适的评估指标,因为评估指标是实验结果的体现。


都看到了这里了,给个小心心♥呗~ 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/624771.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

抖音小店无货源怎么做?超详细教程,新手看这一篇就够了!

哈喽~我是电商月月 实体店做生意开店或是卖菜摆地摊,商家大部分都不是厂家自销,基本都是从批发厂里进货,然后加价卖出去的 “中间商赚差价”的行为在网上做店也是一种合理的行为 抖音小店里面的商家大部分选择的都是无货源模式运营 无货源…

Jmeter接口测试之参数化

在接口测试中,某些时候一些场景会使用到参数化的场景,参数化简单的说就是同一个请求需要不同的数据,比如在性能测试中需要并发多个用户的场景,这样的目的是为了模拟真实的用户场景,需要模拟不同的账号,这里…

2. C++入门:缺省参数及函数重载

缺省参数 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用该函数时&#xff0c;如果没有指定实参则采用该形参的缺省值&#xff0c;否则使用指定的实参。 void Func(int a 0) {cout << a << endl; }int main() {Func();Func(10);return 0; }在形…

【simulink】Scrambling 加扰

https://ww2.mathworks.cn/help/comm/ug/additive-scrambling-of-input-data-in-simulink.html 草图 simulink 代码图

GPT-4o正式发布;零一万物发布千亿参数模型;英国推出AI评估平台

OpenAI 正式发布 GPT-4o 今天凌晨&#xff0c;OpenAI 正式发布 GPT-4o&#xff0c;其中的「o」代表「omni」&#xff08;即全面、全能的意思&#xff09;&#xff0c;这个模型同时具备文本、图片、视频和语音方面的能力&#xff0c;甚至就是 GPT-5 的一个未完成版。 并且&…

mysql 一次删除多个备份表

show tables时&#xff0c;发现备份的表有点多&#xff0c;想要一个sql就删除 总不能drop table xx ; 写多次吧。 方式一 1.生成删除某个数据库下所有的表SQL -- 查询构建批量删除表语句&#xff08;根据数据库名称&#xff09; select concat(drop table , TABLE_NAME, ;)…

Axure10_win安装教程(安装、汉化、授权码,去弹窗)

1.下载Axure10 链接&#xff1a;https://pan.baidu.com/s/1fc8Bgyic8Ct__1IOv-afUg 提取码&#xff1a;9qew 2.安装Axure10 因为我的电脑是Windows操作系统&#xff0c;所以我下载的AxureRP-Setup-Beta v10.0.0.3816 (Win).exe 一直点下一步就行 3.Axure10中文 打开Axure…

fdatool中的幅值响应 怎么计算

设计一个幅值是100&#xff0c;45hz的正弦波&#xff0c;设计一个滤波器&#xff0c;观察滤波器的幅值响应&#xff1a; 滤波前的数据的峰峰值是100*2&#xff1b; 假设滤波后的数据峰峰值变成95*2&#xff1b;&#xff08;95*2是滤波后的数据的峰峰值&#xff0c;不是fft后的值…

豆浆机缺水检测功能如何实现的

豆浆机缺水检测功能的实现是通过光学液位传感器来完成的。这种传感器具有多种优势&#xff0c;如内部所有元器件经过树脂胶封处理&#xff0c;没有任何机械活动部件&#xff0c;免调试、免检验、免维护等特点。它采用了光电液位传感器内置的光学电子元件&#xff0c;体积小、功…

自动化神器Autolt,让你不再重复工作!

随着互联网不断发展&#xff0c;它给我们带来便利的同时&#xff0c;也带来了枯燥、重复、机械的重复工作。今天&#xff0c;我要和大家分享一款老牌实用的自动化工具&#xff1a;AutoIt&#xff0c;它能够让你告别繁琐的重复性工作&#xff0c;提高工作效率。 这里透露一下&am…

python 虚拟环境多种创建方式

【一】说明介绍 &#xff08;1&#xff09;什么是虚拟环境 在Python中&#xff0c;虚拟环境&#xff08;Virtual Environment&#xff09;是一个独立的、隔离的Python运行环境&#xff0c;它拥有自己的Python解释器、第三方库和应用程序。通过创建虚拟环境&#xff0c;可以确…

解析流中 apts 与 vpts的分布

流中 apts 与 vpts的分布情况&#xff0c;同时使用图显示出来 一&#xff0c;最好的方式是使用EasyICE 来查看&#xff0c;这个自动化工具是很好用的&#xff1a; 二&#xff0c; 当EasyICE不能打出理想的数据的时候&#xff0c;可以自己来提取数据&#xff0c;画出对应的图&a…

老铁,对不住了,没有B端成品界面可售,都是定制化设计

经常有粉丝老铁问我有没有成品的UI图可以出售&#xff0c;实在对不住&#xff0c;我们不销售设计成品。 UI设计图是一种设计稿&#xff0c;它是用来展示和呈现产品的界面设计和交互效果的&#xff0c;而不是一个完整的、可交付的成品。UI设计图通常是以静态的形式呈现&#xf…

HTML常见标签-标题标签

标题标签 标题标签一般用于在页面上定义一些标题性的内容,如新闻标题,文章标题等,有h1到h6六级标题 代码 <body><h1>一级标题</h1><h2>二级标题</h2><h3>三级标题</h3><h4>四级标题</h4><h5>五级标题</h5>…

没有申请域名的情况下,用navicat远程连接我们的服务器的Mysql数据库

我们可以根据公网ip用shell来远程连接 首先我们打开自己买的服务器 例如你看这个&#xff0c;就是我们的公网IP 如果服务器里面没有安装mysql数据库的话&#xff0c;那么我们可以用一个轻量级的docker来安装数据库代替一下 我们用docker弄个轻量级的mysql5.7.36&#xff0c;…

【SRC实战】小游戏漏洞修改分数打榜

挖个洞先 https://mp.weixin.qq.com/s/Um0HU2srvZ0UlZRAsbSVug “ 以下漏洞均为实验靶场&#xff0c;如有雷同&#xff0c;纯属巧合 ” 01 — 漏洞证明 “ 如何刷分提高排名&#xff1f;” 1、进入小游戏&#xff0c;类似于跳一跳 2、开始时每次加1分 3、随着游戏进行…

react18【系列实用教程】react-router-dom —— 路由管理 (2024最新版)

类似 vue-router 安装 npm i react-router-domreact-router 中包含 native 的开发&#xff0c;仅网站开发&#xff0c;使用更轻量的 react-router-dom 即可 路由模式 history 模式需要后端支持&#xff0c;使用 createBrowserRouter 函数实现hash 模式无需后端支持&#xff0c;…

一键批量合并视频:掌握视频剪辑技巧解析,轻松创作完美影片

在数字时代的浪潮下&#xff0c;视频已成为人们记录和分享生活的重要工具。然而&#xff0c;对于许多非专业视频编辑者来说&#xff0c;将多个视频片段合并成一个完整的影片却是一项复杂且耗时的任务。幸运的是&#xff0c;云炫AI智剪一键批量合并视频功能的出现&#xff0c;让…

C/C++:Windows动态链接库

动态链接库&#xff08;Dynamic Link Library&#xff0c;简称DLL&#xff09;是在运行时加载的库&#xff0c;它们的代码和数据在内存中与目标程序共享&#xff1b;这意味着多个程序可以共享相同的库实例&#xff0c;并且库的代码可以在不重新编译目标程序的情况下更新。 工作…

【C#】WPF加载浏览器

结果展示 下载SDK 前端代码 红色框住的为添加代码 <Window x:Class"WPFwebview.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://…